Skip to content

Commit

Permalink
Merge pull request #162 from alan-turing-institute/log-row-counts
Browse files: browse the repository at this point in the history
Log row counts for create-data
  • Loading branch information
mhauru authored Dec 4, 2023
2 parents f5a66b7 + 79a77f8 commit 362452a
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 70 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sqlsynthgen"
version = "0.4.0"
version = "0.4.1"
description = "Synthetic SQL data generator"
authors = ["Iain <[email protected]>"]
license = "MIT"
Expand Down
26 changes: 19 additions & 7 deletions sqlsynthgen/create.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functions and classes to create and populate the target database."""
from collections import Counter
from typing import Any, Generator, Mapping, Sequence, Tuple

from sqlalchemy import Connection, insert
Expand All @@ -10,6 +11,7 @@
from sqlsynthgen.utils import create_db_engine, get_sync_engine, logger

Story = Generator[Tuple[str, dict[str, Any]], dict[str, Any], None]
RowCounts = Counter[str]


def create_db_tables(metadata: MetaData) -> None:
Expand Down Expand Up @@ -57,7 +59,7 @@ def create_db_data(
table_generator_dict: Mapping[str, TableGenerator],
story_generator_list: Sequence[Mapping[str, Any]],
num_passes: int,
) -> None:
) -> RowCounts:
"""Connect to a database and populate it with data."""
settings = get_settings()
dst_dsn: str = settings.dst_dsn or ""
Expand All @@ -67,27 +69,30 @@ def create_db_data(
create_db_engine(dst_dsn, schema_name=settings.dst_schema)
)

row_counts: Counter[str] = Counter()
with dst_engine.connect() as dst_conn:
for _ in range(num_passes):
populate(
row_counts += populate(
dst_conn,
sorted_tables,
table_generator_dict,
story_generator_list,
)
return row_counts


def _populate_story(
story: Story,
table_dict: Mapping[str, Table],
table_generator_dict: Mapping[str, TableGenerator],
dst_conn: Connection,
) -> None:
) -> RowCounts:
"""Write to the database all the rows created by the given story."""
# Loop over the rows generated by the story, insert them into their
# respective tables. Ideally this would say
# `for table_name, provided_values in story:`
# but we have to loop more manually to be able to use the `send` function.
row_counts: Counter[str] = Counter()
try:
table_name, provided_values = next(story)
while True:
Expand All @@ -111,19 +116,22 @@ def _populate_story(
else:
return_values = {}
final_values = {**insert_values, **return_values}
row_counts[table_name] = row_counts.get(table_name, 0) + 1
table_name, provided_values = story.send(final_values)
except StopIteration:
# The story has finished, it has no more rows to generate
pass
return row_counts


def populate(
dst_conn: Connection,
tables: Sequence[Table],
table_generator_dict: Mapping[str, TableGenerator],
story_generator_list: Sequence[Mapping[str, Any]],
) -> None:
) -> RowCounts:
"""Populate a database schema with synthetic data."""
row_counts: Counter[str] = Counter()
table_dict = {table.name: table for table in tables}
# Generate stories
# Each story generator returns a python generator (an unfortunate naming clash with
Expand All @@ -141,9 +149,11 @@ def populate(
)
for name, story in stories:
# Run the inserts for each story within a transaction.
logger.debug("Generating data for story %s", name)
logger.debug('Generating data for story "%s".', name)
with dst_conn.begin():
_populate_story(story, table_dict, table_generator_dict, dst_conn)
row_counts += _populate_story(
story, table_dict, table_generator_dict, dst_conn
)

# Generate individual rows, table by table.
for table in tables:
Expand All @@ -154,9 +164,11 @@ def populate(
table_generator = table_generator_dict[table.name]
if table_generator.num_rows_per_pass == 0:
continue
logger.debug("Generating data for table %s", table.name)
logger.debug('Generating data for table "%s".', table.name)
# Run all the inserts for one table in a transaction
with dst_conn.begin():
for _ in range(table_generator.num_rows_per_pass):
stmt = insert(table).values(table_generator(dst_conn))
dst_conn.execute(stmt)
row_counts[table.name] = row_counts.get(table.name, 0) + 1
return row_counts
13 changes: 11 additions & 2 deletions sqlsynthgen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,22 @@ def create_data(
orm_metadata = get_orm_metadata(orm_module, tables_config)
table_generator_dict = ssg_module.table_generator_dict
story_generator_list = ssg_module.story_generator_list
create_db_data(
row_counts = create_db_data(
orm_metadata.sorted_tables,
table_generator_dict,
story_generator_list,
num_passes,
)
logger.debug("Data created in %s passes.", num_passes)
logger.debug(
"Data created in %s %s.", num_passes, "pass" if num_passes == 1 else "passes"
)
for table_name, row_count in row_counts.items():
logger.debug(
"%s: %s %s created.",
table_name,
row_count,
"row" if row_count == 1 else "rows",
)


@app.command()
Expand Down
4 changes: 2 additions & 2 deletions sqlsynthgen/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def remove_db_data(
for table in reversed(metadata.sorted_tables):
# We presume that all tables that aren't vocab should be truncated
if table.name not in ssg_module.vocab_dict:
logger.debug("Truncating table %s", table.name)
logger.debug('Truncating table "%s".', table.name)
dst_conn.execute(delete(table))
dst_conn.commit()

Expand All @@ -50,7 +50,7 @@ def remove_db_vocab(
for table in reversed(metadata.sorted_tables):
# We presume that all tables that are vocab should be truncated
if table.name in ssg_module.vocab_dict:
logger.debug("Truncating vocabulary table %s", table.name)
logger.debug('Truncating vocabulary table "%s".', table.name)
dst_conn.execute(delete(table))
dst_conn.commit()

Expand Down
36 changes: 27 additions & 9 deletions tests/test_create.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for the create module."""
import itertools as itt
from collections import Counter
from pathlib import Path
from typing import Any, Generator, Tuple
from unittest.mock import MagicMock, call, patch
Expand Down Expand Up @@ -34,11 +35,13 @@ def test_create_db_data(
) -> None:
"""Test the generate function."""
mock_get_settings.return_value = get_test_settings()
mock_populate.return_value = {}

num_passes = 23
create_db_data([], {}, [], num_passes)
row_counts = create_db_data([], {}, [], num_passes)

self.assertEqual(len(mock_populate.call_args_list), num_passes)
self.assertEqual(row_counts, {})
mock_create_engine.assert_called()

@patch("sqlsynthgen.create.get_settings")
Expand All @@ -62,13 +65,15 @@ def test_populate(self) -> None:

def story() -> Generator[Tuple[str, dict], None, None]:
"""Mock story."""
yield "table_name", {}
yield table_name, {}

def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
"""A function that returns mock stories."""
return story()

for num_stories_per_pass, num_rows_per_pass in itt.product([0, 2], [0, 3]):
for num_stories_per_pass, num_rows_per_pass, num_initial_rows in itt.product(
[0, 2], [0, 3], [0, 17]
):
with patch("sqlsynthgen.create.insert") as mock_insert:
mock_values = mock_insert.return_value.values
mock_dst_conn = MagicMock(spec=Connection)
Expand All @@ -78,9 +83,10 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
mock_gen = MagicMock(spec=TableGenerator)
mock_gen.num_rows_per_pass = num_rows_per_pass
mock_gen.return_value = {}
row_counts = Counter(
{table_name: num_initial_rows} if num_initial_rows > 0 else {}
)

tables: list[Table] = [mock_table]
row_generators: dict[str, TableGenerator] = {table_name: mock_gen}
story_generators: list[dict[str, Any]] = (
[
{
Expand All @@ -92,13 +98,24 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
if num_stories_per_pass > 0
else []
)
populate(
row_counts += populate(
mock_dst_conn,
tables,
row_generators,
[mock_table],
{table_name: mock_gen},
story_generators,
)

expected_row_count = (
num_stories_per_pass + num_rows_per_pass + num_initial_rows
)
self.assertEqual(
Counter(
{table_name: expected_row_count}
if expected_row_count > 0
else {}
),
row_counts,
)
self.assertListEqual(
[call(mock_dst_conn)] * (num_stories_per_pass + num_rows_per_pass),
mock_gen.call_args_list,
Expand Down Expand Up @@ -135,7 +152,8 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None:
"three": mock_gen_three,
}

populate(mock_dst_conn, tables, row_generators, [])
row_counts = populate(mock_dst_conn, tables, row_generators, [])
self.assertEqual(row_counts, {"two": 1, "three": 1})
self.assertListEqual(
[call(mock_table_two), call(mock_table_three)], mock_insert.call_args_list
)
Expand Down
Loading

0 comments on commit 362452a

Please sign in to comment.