-
Notifications
You must be signed in to change notification settings - Fork 208
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
🧪 TESTS: SQLA Migrations -> pytest (#5192)
This PR primarily converts the SQLA Migrations tests, from using `AiidaTestCase` to using pytest fixtures, and splits them out into multiple files (to improve readability). The tests are functionally equivalent, just restructured and with a reduction in duplication when using `sqlalchemy.Session` contexts (hence the reduction in lines). One more test has been added (`tests/backends/aiida_sqlalchemy/migrations/test_all_basic.py`), to cycle though all migration versions and try migrating them down&up from&to the latest revision (against an empty database). The list of versions is dynamically generated, to ensure none are missed. To ensure a fully empty database, an option has been added to the test manager `aiida_profile.reset_db(with_user=False)` (defaults to `True`), whereby the default User will not be auto-generated. This was required for some migrations to succeed when migrating `down` then `up`, and was how the current unittest implementation worked.
- Loading branch information
1 parent
053efee
commit b806b7f
Showing
18 changed files
with
1,844 additions
and
2,123 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# -*- coding: utf-8 -*- | ||
########################################################################### | ||
# Copyright (c), The AiiDA team. All rights reserved. # | ||
# This file is part of the AiiDA code. # | ||
# # | ||
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # | ||
# For further information on the license, see the LICENSE.txt file # | ||
# For further information please visit http://www.aiida.net # | ||
########################################################################### | ||
"""Tests for the migration engine (Alembic) as well as for the AiiDA migrations for SQLAlchemy.""" | ||
from contextlib import contextmanager | ||
from typing import Iterator | ||
|
||
import pytest | ||
from sqlalchemy.orm import Session | ||
|
||
from aiida.backends.sqlalchemy.manager import SqlaBackendManager | ||
|
||
|
||
class Migrator: | ||
"""A class to yield from the ``perform_migrations`` fixture.""" | ||
|
||
def __init__(self, backend, manager: SqlaBackendManager) -> None: | ||
self.backend = backend | ||
self._manager = manager | ||
|
||
def migrate_up(self, revision: str) -> None: | ||
"""Migrate up to a given revision.""" | ||
self._manager.migrate_up(revision) | ||
if revision != 'head': | ||
assert self._manager.get_schema_version_backend() == revision | ||
|
||
def migrate_down(self, revision: str) -> None: | ||
"""Migrate down to a given revision.""" | ||
self._manager.migrate_down(revision) | ||
assert self._manager.get_schema_version_backend() == revision | ||
|
||
def get_current_table(self, table_name): | ||
""" | ||
Return a Model instantiated at the correct migration. | ||
Note that this is obtained by inspecting the database and not | ||
by looking into the models file. | ||
So, special methods possibly defined in the models files/classes are not present. | ||
For instance, you can do:: | ||
DbGroup = self.get_current_table('db_dbgroup') | ||
:param table_name: the name of the table. | ||
""" | ||
from alembic.migration import MigrationContext # pylint: disable=import-error | ||
from sqlalchemy.ext.automap import automap_base # pylint: disable=import-error,no-name-in-module | ||
|
||
with self.backend.get_session().bind.begin() as connection: | ||
context = MigrationContext.configure(connection) | ||
bind = context.bind | ||
|
||
base = automap_base() | ||
# reflect the tables | ||
base.prepare(autoload_with=bind.engine) | ||
|
||
return getattr(base.classes, table_name) | ||
|
||
@contextmanager | ||
def session(self) -> Iterator[Session]: | ||
"""A context manager for a new session.""" | ||
with self.backend.get_session().bind.begin() as connection: | ||
session = Session(connection.engine, future=True) | ||
try: | ||
yield session | ||
except Exception: | ||
session.rollback() | ||
raise | ||
finally: | ||
session.close() | ||
|
||
|
||
@pytest.fixture() | ||
def perform_migrations(aiida_profile, backend, request): | ||
"""A fixture to setup the database for migration tests""" | ||
# note downgrading to 1830c8430131 requires adding columns to `DbUser` and hangs if a user is present | ||
aiida_profile.reset_db(with_user=False) | ||
migrator = Migrator(backend, SqlaBackendManager()) | ||
marker = request.node.get_closest_marker('migrate_down') | ||
if marker is not None: | ||
assert marker.args, 'No version given' | ||
migrator.migrate_down(marker.args[0]) | ||
yield migrator | ||
# clear the database | ||
# note this assumes the current schema contains the tables specified in `clean_db` | ||
aiida_profile.reset_db(with_user=False) | ||
# ensure that the database is migrated back up to the latest version, once finished | ||
migrator.migrate_up('head') |
102 changes: 102 additions & 0 deletions
102
tests/backends/aiida_sqlalchemy/migrations/test_10_group_update.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# -*- coding: utf-8 -*- | ||
########################################################################### | ||
# Copyright (c), The AiiDA team. All rights reserved. # | ||
# This file is part of the AiiDA code. # | ||
# # | ||
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # | ||
# For further information on the license, see the LICENSE.txt file # | ||
# For further information please visit http://www.aiida.net # | ||
########################################################################### | ||
"""Tests for group migrations: 118349c10896 -> 0edcdd5a30f0""" | ||
from .conftest import Migrator | ||
|
||
|
||
def test_group_typestring(perform_migrations: Migrator): | ||
"""Test the migration that renames the DbGroup type strings. | ||
Verify that the type strings are properly migrated. | ||
""" | ||
# starting revision | ||
perform_migrations.migrate_down('118349c10896') # 118349c10896_default_link_label.py | ||
|
||
# setup the database | ||
DbGroup = perform_migrations.get_current_table('db_dbgroup') # pylint: disable=invalid-name | ||
DbUser = perform_migrations.get_current_table('db_dbuser') # pylint: disable=invalid-name | ||
with perform_migrations.session() as session: | ||
default_user = DbUser(email='[email protected]') | ||
session.add(default_user) | ||
session.commit() | ||
|
||
# test user group type_string: 'user' -> 'core' | ||
group_user = DbGroup(label='01', user_id=default_user.id, type_string='user') | ||
session.add(group_user) | ||
# test upf group type_string: 'data.upf' -> 'core.upf' | ||
group_data_upf = DbGroup(label='02', user_id=default_user.id, type_string='data.upf') | ||
session.add(group_data_upf) | ||
# test auto.import group type_string: 'auto.import' -> 'core.import' | ||
group_autoimport = DbGroup(label='03', user_id=default_user.id, type_string='auto.import') | ||
session.add(group_autoimport) | ||
# test auto.run group type_string: 'auto.run' -> 'core.auto' | ||
group_autorun = DbGroup(label='04', user_id=default_user.id, type_string='auto.run') | ||
session.add(group_autorun) | ||
|
||
session.commit() | ||
|
||
# Store values for later tests | ||
group_user_pk = group_user.id | ||
group_data_upf_pk = group_data_upf.id | ||
group_autoimport_pk = group_autoimport.id | ||
group_autorun_pk = group_autorun.id | ||
|
||
# migrate up | ||
perform_migrations.migrate_up('bf591f31dd12') # bf591f31dd12_dbgroup_type_string.py | ||
|
||
# perform some checks | ||
DbGroup = perform_migrations.get_current_table('db_dbgroup') # pylint: disable=invalid-name | ||
with perform_migrations.session() as session: | ||
group_user = session.query(DbGroup).filter(DbGroup.id == group_user_pk).one() | ||
assert group_user.type_string == 'core' | ||
|
||
# test upf group type_string: 'data.upf' -> 'core.upf' | ||
group_data_upf = session.query(DbGroup).filter(DbGroup.id == group_data_upf_pk).one() | ||
assert group_data_upf.type_string == 'core.upf' | ||
|
||
# test auto.import group type_string: 'auto.import' -> 'core.import' | ||
group_autoimport = session.query(DbGroup).filter(DbGroup.id == group_autoimport_pk).one() | ||
assert group_autoimport.type_string == 'core.import' | ||
|
||
# test auto.run group type_string: 'auto.run' -> 'core.auto' | ||
group_autorun = session.query(DbGroup).filter(DbGroup.id == group_autorun_pk).one() | ||
assert group_autorun.type_string == 'core.auto' | ||
|
||
|
||
def test_group_extras(perform_migrations: Migrator): | ||
"""Test migration to add the `extras` JSONB column to the `DbGroup` model. | ||
Verify that the model now has an extras column with empty dictionary as default. | ||
""" | ||
# starting revision | ||
perform_migrations.migrate_down('bf591f31dd12') # bf591f31dd12_dbgroup_type_string.py | ||
|
||
# setup the database | ||
DbGroup = perform_migrations.get_current_table('db_dbgroup') # pylint: disable=invalid-name | ||
DbUser = perform_migrations.get_current_table('db_dbuser') # pylint: disable=invalid-name | ||
with perform_migrations.session() as session: | ||
default_user = DbUser(email='[email protected]') | ||
session.add(default_user) | ||
session.commit() | ||
|
||
group = DbGroup(label='01', user_id=default_user.id, type_string='user') | ||
session.add(group) | ||
session.commit() | ||
|
||
group_pk = group.id | ||
|
||
# migrate up | ||
perform_migrations.migrate_up('0edcdd5a30f0') # 0edcdd5a30f0_dbgroup_extras.py | ||
|
||
# perform some checks | ||
DbGroup = perform_migrations.get_current_table('db_dbgroup') # pylint: disable=invalid-name | ||
with perform_migrations.session() as session: | ||
group = session.query(DbGroup).filter(DbGroup.id == group_pk).one() | ||
assert group.extras == {} |
Oops, something went wrong.