Skip to content

Commit

Permalink
🧪 TESTS: SQLA Migrations -> pytest (#5192)
Browse files Browse the repository at this point in the history
This PR primarily converts the SQLA Migrations tests,
from using `AiidaTestCase` to using pytest fixtures,
and splits them out into multiple files
(to improve readability).

The tests are functionally equivalent,
just restructured and with a reduction in duplication when using `sqlalchemy.Session` contexts
(hence the reduction in lines).

One more test has been added (`tests/backends/aiida_sqlalchemy/migrations/test_all_basic.py`),
to cycle though all migration versions and try migrating them down&up from&to the latest revision
(against an empty database).
The list of versions is dynamically generated,
to ensure none are missed.

To ensure a fully empty database,
an option has been added to the test manager `aiida_profile.reset_db(with_user=False)` (defaults to `True`),
whereby the default User will not be auto-generated.
This was required for some migrations to succeed when migrating `down` then `up`,
and was how the current unittest implementation worked.
  • Loading branch information
chrisjsewell authored Dec 8, 2021
1 parent 053efee commit b806b7f
Show file tree
Hide file tree
Showing 18 changed files with 1,844 additions and 2,123 deletions.
12 changes: 6 additions & 6 deletions aiida/manage/tests/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ def use_profile(self, profile_name):
def has_profile_open(self):
return self._manager and self._manager.has_profile_open()

def reset_db(self):
return self._manager.reset_db()
def reset_db(self, with_user=True):
return self._manager.reset_db(with_user=with_user)

def destroy_all(self):
if self._manager:
Expand Down Expand Up @@ -166,20 +166,20 @@ def _select_db_test_case(self, backend):
self._test_case = SqlAlchemyTests()
self._test_case.test_session = get_scoped_session()

def reset_db(self):
def reset_db(self, with_user=True):
self._test_case.clean_db() # will drop all users
manager.reset_manager()
self.init_db()
self.init_db(with_user=with_user)

def init_db(self):
def init_db(self, with_user=True):
"""Initialise the database state for running of tests.
Adds default user if necessary.
"""
from aiida.cmdline.commands.cmd_user import set_default_user
from aiida.orm import User

if not User.objects.get_default():
if with_user and not User.objects.get_default():
user_dict = get_user_dict(_DEFAULT_PROFILE_INFO)
try:
user = User(**user_dict)
Expand Down
Empty file.
93 changes: 93 additions & 0 deletions tests/backends/aiida_sqlalchemy/migrations/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for the migration engine (Alembic) as well as for the AiiDA migrations for SQLAlchemy."""
from contextlib import contextmanager
from typing import Iterator

import pytest
from sqlalchemy.orm import Session

from aiida.backends.sqlalchemy.manager import SqlaBackendManager


class Migrator:
"""A class to yield from the ``perform_migrations`` fixture."""

def __init__(self, backend, manager: SqlaBackendManager) -> None:
self.backend = backend
self._manager = manager

def migrate_up(self, revision: str) -> None:
"""Migrate up to a given revision."""
self._manager.migrate_up(revision)
if revision != 'head':
assert self._manager.get_schema_version_backend() == revision

def migrate_down(self, revision: str) -> None:
"""Migrate down to a given revision."""
self._manager.migrate_down(revision)
assert self._manager.get_schema_version_backend() == revision

def get_current_table(self, table_name):
"""
Return a Model instantiated at the correct migration.
Note that this is obtained by inspecting the database and not
by looking into the models file.
So, special methods possibly defined in the models files/classes are not present.
For instance, you can do::
DbGroup = self.get_current_table('db_dbgroup')
:param table_name: the name of the table.
"""
from alembic.migration import MigrationContext # pylint: disable=import-error
from sqlalchemy.ext.automap import automap_base # pylint: disable=import-error,no-name-in-module

with self.backend.get_session().bind.begin() as connection:
context = MigrationContext.configure(connection)
bind = context.bind

base = automap_base()
# reflect the tables
base.prepare(autoload_with=bind.engine)

return getattr(base.classes, table_name)

@contextmanager
def session(self) -> Iterator[Session]:
"""A context manager for a new session."""
with self.backend.get_session().bind.begin() as connection:
session = Session(connection.engine, future=True)
try:
yield session
except Exception:
session.rollback()
raise
finally:
session.close()


@pytest.fixture()
def perform_migrations(aiida_profile, backend, request):
"""A fixture to setup the database for migration tests"""
# note downgrading to 1830c8430131 requires adding columns to `DbUser` and hangs if a user is present
aiida_profile.reset_db(with_user=False)
migrator = Migrator(backend, SqlaBackendManager())
marker = request.node.get_closest_marker('migrate_down')
if marker is not None:
assert marker.args, 'No version given'
migrator.migrate_down(marker.args[0])
yield migrator
# clear the database
# note this assumes the current schema contains the tables specified in `clean_db`
aiida_profile.reset_db(with_user=False)
# ensure that the database is migrated back up to the latest version, once finished
migrator.migrate_up('head')
102 changes: 102 additions & 0 deletions tests/backends/aiida_sqlalchemy/migrations/test_10_group_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for group migrations: 118349c10896 -> 0edcdd5a30f0"""
from .conftest import Migrator


def test_group_typestring(perform_migrations: Migrator):
"""Test the migration that renames the DbGroup type strings.
Verify that the type strings are properly migrated.
"""
# starting revision
perform_migrations.migrate_down('118349c10896') # 118349c10896_default_link_label.py

# setup the database
DbGroup = perform_migrations.get_current_table('db_dbgroup') # pylint: disable=invalid-name
DbUser = perform_migrations.get_current_table('db_dbuser') # pylint: disable=invalid-name
with perform_migrations.session() as session:
default_user = DbUser(email='[email protected]')
session.add(default_user)
session.commit()

# test user group type_string: 'user' -> 'core'
group_user = DbGroup(label='01', user_id=default_user.id, type_string='user')
session.add(group_user)
# test upf group type_string: 'data.upf' -> 'core.upf'
group_data_upf = DbGroup(label='02', user_id=default_user.id, type_string='data.upf')
session.add(group_data_upf)
# test auto.import group type_string: 'auto.import' -> 'core.import'
group_autoimport = DbGroup(label='03', user_id=default_user.id, type_string='auto.import')
session.add(group_autoimport)
# test auto.run group type_string: 'auto.run' -> 'core.auto'
group_autorun = DbGroup(label='04', user_id=default_user.id, type_string='auto.run')
session.add(group_autorun)

session.commit()

# Store values for later tests
group_user_pk = group_user.id
group_data_upf_pk = group_data_upf.id
group_autoimport_pk = group_autoimport.id
group_autorun_pk = group_autorun.id

# migrate up
perform_migrations.migrate_up('bf591f31dd12') # bf591f31dd12_dbgroup_type_string.py

# perform some checks
DbGroup = perform_migrations.get_current_table('db_dbgroup') # pylint: disable=invalid-name
with perform_migrations.session() as session:
group_user = session.query(DbGroup).filter(DbGroup.id == group_user_pk).one()
assert group_user.type_string == 'core'

# test upf group type_string: 'data.upf' -> 'core.upf'
group_data_upf = session.query(DbGroup).filter(DbGroup.id == group_data_upf_pk).one()
assert group_data_upf.type_string == 'core.upf'

# test auto.import group type_string: 'auto.import' -> 'core.import'
group_autoimport = session.query(DbGroup).filter(DbGroup.id == group_autoimport_pk).one()
assert group_autoimport.type_string == 'core.import'

# test auto.run group type_string: 'auto.run' -> 'core.auto'
group_autorun = session.query(DbGroup).filter(DbGroup.id == group_autorun_pk).one()
assert group_autorun.type_string == 'core.auto'


def test_group_extras(perform_migrations: Migrator):
"""Test migration to add the `extras` JSONB column to the `DbGroup` model.
Verify that the model now has an extras column with empty dictionary as default.
"""
# starting revision
perform_migrations.migrate_down('bf591f31dd12') # bf591f31dd12_dbgroup_type_string.py

# setup the database
DbGroup = perform_migrations.get_current_table('db_dbgroup') # pylint: disable=invalid-name
DbUser = perform_migrations.get_current_table('db_dbuser') # pylint: disable=invalid-name
with perform_migrations.session() as session:
default_user = DbUser(email='[email protected]')
session.add(default_user)
session.commit()

group = DbGroup(label='01', user_id=default_user.id, type_string='user')
session.add(group)
session.commit()

group_pk = group.id

# migrate up
perform_migrations.migrate_up('0edcdd5a30f0') # 0edcdd5a30f0_dbgroup_extras.py

# perform some checks
DbGroup = perform_migrations.get_current_table('db_dbgroup') # pylint: disable=invalid-name
with perform_migrations.session() as session:
group = session.query(DbGroup).filter(DbGroup.id == group_pk).one()
assert group.extras == {}
Loading

0 comments on commit b806b7f

Please sign in to comment.