diff --git a/sqlsynthgen/create.py b/sqlsynthgen/create.py index 6e7d9c2..02bd224 100644 --- a/sqlsynthgen/create.py +++ b/sqlsynthgen/create.py @@ -26,18 +26,23 @@ def create_db_tables(metadata: Any) -> Any: def create_db_data(sorted_tables: list, sorted_generators: list, num_rows: int) -> None: """Connect to a database and populate it with data.""" settings = get_settings() - engine = create_engine(settings.dst_postgres_dsn) + dst_engine = create_engine(settings.dst_postgres_dsn) + src_engine = create_engine(settings.src_postgres_dsn) - with engine.connect() as conn: - populate(conn, sorted_tables, sorted_generators, num_rows) + with dst_engine.connect() as dst_conn: + with src_engine.connect() as src_conn: + populate(src_conn, dst_conn, sorted_tables, sorted_generators, num_rows) -def populate(conn: Any, tables: list, generators: list, num_rows: int) -> None: +def populate( + src_conn: Any, dst_conn: Any, tables: list, generators: list, num_rows: int +) -> None: """Populate a database schema with dummy data.""" - for table, generator in zip(tables, generators): - # Run all the inserts for one table in a transaction - with conn.begin(): + for table, generator in zip( + tables, generators + ): # Run all the inserts for one table in a transaction + with dst_conn.begin(): for _ in range(num_rows): - stmt = insert(table).values(generator(conn).__dict__) - conn.execute(stmt) + stmt = insert(table).values(generator(src_conn, dst_conn).__dict__) + dst_conn.execute(stmt) diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index cff43c9..ebe4694 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -54,7 +54,7 @@ def make_generators_from_tables(tables_module: ModuleType) -> str: + new_class_name + ":\n" + INDENTATION - + "def __init__(self, db_connection):\n" + + "def __init__(self, src_db_conn, dst_db_conn):\n" ) for column in table.columns: @@ -70,7 +70,7 @@ def make_generators_from_tables(tables_module: ModuleType) -> str: fk_schema, fk_table, fk_column = fk_column_path.split(".") new_content += ( f"{INDENTATION*2}self.{column.name} = " - f"generic.column_value_provider.column_value(db_connection, " + f"generic.column_value_provider.column_value(dst_db_conn, " f'"{fk_schema}", "{fk_table}", "{fk_column}"' ")\n" ) diff --git a/tests/examples/expected_ssg.py b/tests/examples/expected_ssg.py index 10a7168..a906915 100644 --- a/tests/examples/expected_ssg.py +++ b/tests/examples/expected_ssg.py @@ -9,7 +9,7 @@ class personGenerator: - def __init__(self, db_connection): + def __init__(self, src_db_conn, dst_db_conn): self.name = generic.text.color() self.nhs_number = generic.text.color() self.research_opt_out = generic.development.boolean() @@ -18,8 +18,8 @@ def __init__(self, db_connection): class hospital_visitGenerator: - def __init__(self, db_connection): - self.person_id = generic.column_value_provider.column_value(db_connection, "myschema", "person", "person_id") + def __init__(self, src_db_conn, dst_db_conn): + self.person_id = generic.column_value_provider.column_value(dst_db_conn, "myschema", "person", "person_id") self.visit_start = generic.datetime.datetime() self.visit_end = generic.datetime.date() self.visit_duration_seconds = generic.numeric.float_number() diff --git a/tests/test_create.py b/tests/test_create.py index de1b163..3041b8b 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -2,7 +2,7 @@ from unittest import TestCase from unittest.mock import MagicMock, patch -from sqlsynthgen.create import create_db_data, create_db_tables +from sqlsynthgen.create import create_db_data, create_db_tables, populate from tests.utils import get_test_settings @@ -21,7 +21,7 @@ def test_create_db_data(self) -> None: create_db_data([], [], 0) mock_populate.assert_called_once() - mock_create_engine.assert_called_once() + mock_create_engine.assert_called() def test_create_db_tables(self) -> None: """Test the create_tables function.""" @@ -36,3 +36,21 @@ def test_create_db_tables(self) -> None: mock_create_engine.assert_called_once_with( mock_get_settings.return_value.dst_postgres_dsn ) + + def test_populate(self) -> None: + """Test the populate function.""" + with patch("sqlsynthgen.create.insert") as mock_insert: + mock_src_conn = MagicMock() + mock_dst_conn = MagicMock() + mock_gen = MagicMock() + tables = [None] + generators = [mock_gen] + populate(mock_src_conn, mock_dst_conn, tables, generators, 1) + + mock_gen.assert_called_once_with(mock_src_conn, mock_dst_conn) + mock_insert.return_value.values.assert_called_once_with( + mock_gen.return_value.__dict__ + ) + mock_dst_conn.execute.assert_called_once_with( + mock_insert.return_value.values.return_value + ) diff --git a/tests/test_functional.py b/tests/test_functional.py index 33de3a9..2f4a83c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -53,7 +53,7 @@ def test_workflow(self) -> None: run(["sqlsynthgen", "create-tables", self.orm_file_path], env=env, check=True) run( - ["sqlsynthgen", "create-data", self.orm_file_path, self.ssg_file_path], + ["sqlsynthgen", "create-data", self.orm_file_path, self.ssg_file_path, "1"], env=env, check=True, ) diff --git a/tests/test_providers.py b/tests/test_providers.py index fb37af4..fe79fbb 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -30,7 +30,8 @@ class BinaryProviderTestCase(TestCase): """Tests for the BytesProvider class.""" def test_bytes(self) -> None: - BytesProvider().bytes().decode("utf-8") + """Test the bytes method.""" + self.assertTrue(BytesProvider().bytes().decode("utf-8") != "") @skipUnless(