diff --git a/docs/source/tutorials.rst b/docs/source/tutorials.rst index b52de6e..7d0d291 100644 --- a/docs/source/tutorials.rst +++ b/docs/source/tutorials.rst @@ -26,48 +26,21 @@ In the source database, remove the circular foreign key between `concept` and `v alter table concept drop constraint concept.concept_vocabulary_id_fkey +and between `concept` and `domain` with, for example: + +.. code-block:: sql + + alter table concept drop constraint concept.concept_domain_id_fkey + + Create a config file ++++++++++++++++++++ Make a config file called `omop.yaml`. At the very least, our config file will need to specify the tables that need to be copied over in their entirety: -.. code-block:: yaml - - tables: - # Standardized Vocabularies - concept: - vocabulary_table: true - concept_class - vocabulary_table: true - concept_relationship: - vocabulary_table: true - concept_synonym: - vocabulary_table: true - domain: - vocabulary_table: true - drug_strength: - vocabulary_table: true - cohort_definition: - vocabulary_table: true - attribute_definition: - vocabulary_table: true - relationship: - vocabulary_table: true - source_to_concept_map - vocabulary_table: true - vocabulary: - vocabulary_table: true - # Standardized meta-data - cdm_source: - vocabulary_table: true - # Standardized health system data - location: - vocabulary_table: true - care_site: - vocabulary_table: true - provider: - vocabulary_table: true +.. literalinclude:: ../../tests/examples/omop/config.yaml + :language: yaml Make SQLAlchemy file ++++++++++++++++++++ diff --git a/sqlsynthgen/base.py b/sqlsynthgen/base.py index 3f247e2..8a5c4a5 100644 --- a/sqlsynthgen/base.py +++ b/sqlsynthgen/base.py @@ -19,5 +19,7 @@ def load(self, connection: Any) -> None: "r", newline="", encoding="utf-8" ) as yamlfile: rows = yaml.load(yamlfile, Loader=yaml.Loader) + if not rows: + return stmt = insert(self.table).values(list(rows)) connection.execute(stmt) diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index 67473bd..18fa03d 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -2,7 +2,7 @@ import inspect from sys import stderr from types import ModuleType -from typing import Any, Final, Optional +from typing import Any, Final, Optional, Tuple import snsql from mimesis.providers.base import BaseProvider @@ -48,12 +48,20 @@ } -def _orm_class_from_table_name(tables_module: Any, full_name: str) -> Optional[Any]: +def _orm_class_from_table_name( + tables_module: Any, full_name: str +) -> Optional[Tuple[str, str]]: """Return the ORM class corresponding to a table name.""" + # If the class in tables_module is an SQLAlchemy ORM class for mapper in tables_module.Base.registry.mappers: cls = mapper.class_ if cls.__table__.fullname == full_name: - return cls + return cls.__name__, cls.__name__ + ".__table__" + + # If the class in tables_module is a SQLAlchemy Core Table + guess = "t_" + full_name + if guess in dir(tables_module): + return guess, guess return None @@ -100,13 +108,16 @@ def _add_default_generator(content: str, tables_module: ModuleType, column: Any) target_name_parts = fkey.target_fullname.split(".") target_table_name = ".".join(target_name_parts[:-1]) target_column_name = target_name_parts[-1] - target_orm_class = _orm_class_from_table_name(tables_module, target_table_name) - if target_orm_class is None: + class_and_name = _orm_class_from_table_name(tables_module, target_table_name) + if not class_and_name: raise ValueError(f"Could not find the ORM class for {target_table_name}.") + + target_orm_class, _ = class_and_name + content += ( f"self.{column.name} = " f"generic.column_value_provider.column_value(dst_db_conn, " - f"{tables_module.__name__}.{target_orm_class.__name__}, " + f"{tables_module.__name__}.{target_orm_class}, " f'"{target_column_name}"' ")" ) @@ -180,13 +191,18 @@ def make_generators_from_tables( if table_config.get("vocabulary_table") is True: - orm_class = _orm_class_from_table_name(tables_module, table.fullname) - if not orm_class: + class_and_name = _orm_class_from_table_name(tables_module, table.fullname) + + if not class_and_name: raise RuntimeError(f"Couldn't find {table.fullname} in {tables_module}") - class_name = orm_class.__name__ + + class_name, table_name = class_and_name + + the_table_to_download = f"{tables_module.__name__}.{table_name}" + new_content += ( f"\n\n{class_name.lower()}_vocab " - f"= FileUploader({tables_module.__name__}.{class_name}.__table__)" + f"= FileUploader({the_table_to_download})" ) vocab_dict += f'{INDENTATION}"{table.name}": {class_name.lower()}_vocab,\n' diff --git a/sqlsynthgen/utils.py b/sqlsynthgen/utils.py index 850b6b0..93abd61 100644 --- a/sqlsynthgen/utils.py +++ b/sqlsynthgen/utils.py @@ -47,8 +47,9 @@ def download_table(table: Any, engine: Any) -> None: yaml_file_name = table.fullname + ".yaml" yaml_file_path = Path(yaml_file_name) if yaml_file_path.exists(): - print(f"{str(yaml_file_name)} already exists. Exiting...", file=stderr) - sys.exit(1) + # print(f"{str(yaml_file_name)} already exists. Exiting...", file=stderr) + # sys.exit(1) + print(f"Warning: {str(yaml_file_name)} already exists.", file=stderr) stmt = select([table]) with engine.connect() as conn: diff --git a/tests/examples/omop/config.yaml b/tests/examples/omop/config.yaml index e69de29..73c706f 100644 --- a/tests/examples/omop/config.yaml +++ b/tests/examples/omop/config.yaml @@ -0,0 +1,36 @@ +tables: + # Standardized Vocabularies + concept: + vocabulary_table: true + concept_ancestor: + vocabulary_table: true + concept_class: + vocabulary_table: true + concept_relationship: + vocabulary_table: true + concept_synonym: + vocabulary_table: true + domain: + vocabulary_table: true + drug_strength: + vocabulary_table: true + cohort_definition: + vocabulary_table: true + attribute_definition: + vocabulary_table: true + relationship: + vocabulary_table: true + source_to_concept_map: + vocabulary_table: true + vocabulary: + vocabulary_table: true + # Standardized meta-data + cdm_source: + vocabulary_table: true + # Standardized health system data + location: + vocabulary_table: true + care_site: + vocabulary_table: true + provider: + vocabulary_table: true diff --git a/tests/test_make.py b/tests/test_make.py index 1814ff6..b43d45a 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -55,6 +55,30 @@ def test_make_generators_from_tables( self.assertEqual(expected, actual) + @patch("sqlsynthgen.make.get_settings") + @patch("sqlsynthgen.make.create_engine") + @patch("sqlsynthgen.make.download_table") + def test_make_generators_from_table( + self, + mock_download: MagicMock, + mock_create: MagicMock, + mock_get_settings: MagicMock, + ) -> None: + """Check that we can make a generators file from a tables module.""" + mock_get_settings.return_value = get_test_settings() + with open("expected_ssg.py", encoding="utf-8") as expected_output: + expected = expected_output.read() + conf_path = "example_config.yaml" + with open(conf_path, "r", encoding="utf8") as f: + config = yaml.safe_load(f) + stats_path = "example_stats.yaml" + + actual = make_generators_from_tables(example_orm, config, stats_path) + mock_download.assert_called_once() + mock_create.assert_called_once() + + self.assertEqual(expected, actual) + class TestMakeTables(SSGTestCase): """Test the make_tables function."""