Skip to content

Commit

Permalink
Pass a src db connection to the generators
Browse files Browse the repository at this point in the history
In addition to the destination connection they already have
  • Loading branch information
Iain-S committed Jan 23, 2023
1 parent 21453ec commit 1ed1542
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 18 deletions.
23 changes: 14 additions & 9 deletions sqlsynthgen/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,23 @@ def create_db_tables(metadata: Any) -> Any:
def create_db_data(sorted_tables: list, sorted_generators: list, num_rows: int) -> None:
"""Connect to a database and populate it with data."""
settings = get_settings()
engine = create_engine(settings.dst_postgres_dsn)
dst_engine = create_engine(settings.dst_postgres_dsn)
src_engine = create_engine(settings.src_postgres_dsn)

with engine.connect() as conn:
populate(conn, sorted_tables, sorted_generators, num_rows)
with dst_engine.connect() as dst_conn:
with src_engine.connect() as src_conn:
populate(src_conn, dst_conn, sorted_tables, sorted_generators, num_rows)


def populate(conn: Any, tables: list, generators: list, num_rows: int) -> None:
def populate(
src_conn: Any, dst_conn: Any, tables: list, generators: list, num_rows: int
) -> None:
"""Populate a database schema with dummy data."""

for table, generator in zip(tables, generators):
# Run all the inserts for one table in a transaction
with conn.begin():
for table, generator in zip(
tables, generators
): # Run all the inserts for one table in a transaction
with dst_conn.begin():
for _ in range(num_rows):
stmt = insert(table).values(generator(conn).__dict__)
conn.execute(stmt)
stmt = insert(table).values(generator(src_conn, dst_conn).__dict__)
dst_conn.execute(stmt)
4 changes: 2 additions & 2 deletions sqlsynthgen/make.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def make_generators_from_tables(tables_module: ModuleType) -> str:
+ new_class_name
+ ":\n"
+ INDENTATION
+ "def __init__(self, db_connection):\n"
+ "def __init__(self, src_db_conn, dst_db_conn):\n"
)

for column in table.columns:
Expand All @@ -72,7 +72,7 @@ def make_generators_from_tables(tables_module: ModuleType) -> str:
fk_schema, fk_table, fk_column = fk_column_path.split(".")
new_content += (
f"{INDENTATION*2}self.{column.name} = "
f"generic.column_value_provider.column_value(db_connection, "
f"generic.column_value_provider.column_value(dst_db_conn, "
f'"{fk_schema}", "{fk_table}", "{fk_column}"'
")\n"
)
Expand Down
8 changes: 4 additions & 4 deletions tests/examples/expected_ssg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@


class entityGenerator:
def __init__(self, db_connection):
def __init__(self, src_db_conn, dst_db_conn):
pass


class personGenerator:
def __init__(self, db_connection):
def __init__(self, src_db_conn, dst_db_conn):
pass
self.name = generic.text.color()
self.nhs_number = generic.text.color()
Expand All @@ -24,9 +24,9 @@ def __init__(self, db_connection):


class hospital_visitGenerator:
def __init__(self, db_connection):
def __init__(self, src_db_conn, dst_db_conn):
pass
self.person_id = generic.column_value_provider.column_value(db_connection, "myschema", "person", "person_id")
self.person_id = generic.column_value_provider.column_value(dst_db_conn, "myschema", "person", "person_id")
self.visit_start = generic.datetime.datetime()
self.visit_end = generic.datetime.date()
self.visit_duration_seconds = generic.numeric.float_number()
Expand Down
22 changes: 20 additions & 2 deletions tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from unittest import TestCase
from unittest.mock import MagicMock, patch

from sqlsynthgen.create import create_db_data, create_db_tables
from sqlsynthgen.create import create_db_data, create_db_tables, populate
from tests.utils import get_test_settings


Expand All @@ -21,7 +21,7 @@ def test_create_db_data(self) -> None:
create_db_data([], [], 0)

mock_populate.assert_called_once()
mock_create_engine.assert_called_once()
mock_create_engine.assert_called()

def test_create_db_tables(self) -> None:
"""Test the create_tables function."""
Expand All @@ -36,3 +36,21 @@ def test_create_db_tables(self) -> None:
mock_create_engine.assert_called_once_with(
mock_get_settings.return_value.dst_postgres_dsn
)

def test_populate(self) -> None:
"""Test the populate function."""
with patch("sqlsynthgen.create.insert") as mock_insert:
mock_src_conn = MagicMock()
mock_dst_conn = MagicMock()
mock_gen = MagicMock()
tables = [None]
generators = [mock_gen]
populate(mock_src_conn, mock_dst_conn, tables, generators, 1)

mock_gen.assert_called_once_with(mock_src_conn, mock_dst_conn)
mock_insert.return_value.values.assert_called_once_with(
mock_gen.return_value.__dict__
)
mock_dst_conn.execute.assert_called_once_with(
mock_insert.return_value.values.return_value
)
3 changes: 2 additions & 1 deletion tests/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class BinaryProviderTestCase(TestCase):
"""Tests for the BytesProvider class."""

def test_bytes(self) -> None:
BytesProvider().bytes().decode("utf-8")
"""Test the bytes method."""
self.assertTrue(BytesProvider().bytes().decode("utf-8") != "")


@skipUnless(
Expand Down

0 comments on commit 1ed1542

Please sign in to comment.