Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

column defaults #140

Merged
merged 4 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ jobs:
uses: actions/cache@v3
with:
path: ${{ env.PRE_COMMIT_HOME }}
key: hooks-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }}
key: hooks-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('poetry.lock') }}
- name: Install Pre-Commit Hooks
shell: bash
if: steps.pre-commit-cache.outputs.cache-hit != 'true'
Expand Down
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions sqlsynthgen/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ def _populate_story(
else:
default_values = {}
insert_values = {**default_values, **provided_values}
stmt = insert(table).values(insert_values)
stmt = insert(table).values(insert_values).return_defaults()
cursor = dst_conn.execute(stmt)
# We need to return all the default values etc. to the generator,
# because other parts of the story may refer to them.
if cursor.returned_defaults:
# pylint: disable=protected-access
return_values = dict(cursor.returned_defaults._mapping.items())
return_values = cursor.returned_defaults._mapping
# pylint: enable=protected-access
else:
return_values = {}
Expand Down
119 changes: 88 additions & 31 deletions tests/test_create.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
"""Tests for the create module."""
import itertools as itt
from pathlib import Path
from typing import Any, Generator, Tuple
from unittest.mock import MagicMock, call, patch

from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import declarative_base

from sqlsynthgen.create import (
Story,
_populate_story,
create_db_data,
create_db_tables,
create_db_vocab,
populate,
)
from tests.utils import SSGTestCase, get_test_settings
from tests.utils import RequiresDBTestCase, SSGTestCase, get_test_settings, run_psql


class MyTestCase(SSGTestCase):
Expand Down Expand Up @@ -48,8 +54,7 @@ def test_create_db_tables(
)
mock_meta.create_all.assert_called_once_with(mock_create_engine.return_value)

@patch("sqlsynthgen.create.insert")
def test_populate(self, mock_insert: MagicMock) -> None:
def test_populate(self) -> None:
"""Test the populate function."""
table_name = "table_name"

Expand All @@ -62,34 +67,47 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
return story()

for num_stories_per_pass, num_rows_per_pass in itt.product([0, 2], [0, 3]):
mock_dst_conn = MagicMock()
mock_dst_conn.execute.return_value.returned_defaults = {}
mock_table = MagicMock()
mock_table.name = table_name
mock_gen = MagicMock()
mock_gen.num_rows_per_pass = num_rows_per_pass
mock_gen.return_value = {}

tables = [mock_table]
row_generators = {table_name: mock_gen}
story_generators = (
[{"name": mock_story_gen, "num_stories_per_pass": num_stories_per_pass}]
if num_stories_per_pass > 0
else []
)
populate(mock_dst_conn, tables, row_generators, story_generators)

mock_gen.assert_has_calls(
[call(mock_dst_conn)] * (num_stories_per_pass + num_rows_per_pass)
)
mock_insert.return_value.values.assert_has_calls(
[call(mock_gen.return_value)]
* (num_stories_per_pass + num_rows_per_pass)
)
mock_dst_conn.execute.assert_has_calls(
[call(mock_insert.return_value.values.return_value)]
* (num_stories_per_pass + num_rows_per_pass)
)
with patch("sqlsynthgen.create.insert") as mock_insert:
mock_values = mock_insert.return_value.values
mock_dst_conn = MagicMock()
mock_dst_conn.execute.return_value.returned_defaults = {}
mock_table = MagicMock()
mock_table.name = table_name
mock_gen = MagicMock()
mock_gen.num_rows_per_pass = num_rows_per_pass
mock_gen.return_value = {}

tables = [mock_table]
row_generators = {table_name: mock_gen}
story_generators = (
[
{
"name": mock_story_gen,
"num_stories_per_pass": num_stories_per_pass,
}
]
if num_stories_per_pass > 0
else []
)
populate(mock_dst_conn, tables, row_generators, story_generators)

self.assertListEqual(
[call(mock_dst_conn)] * (num_stories_per_pass + num_rows_per_pass),
mock_gen.call_args_list,
)
self.assertListEqual(
[call(mock_gen.return_value)]
* (num_stories_per_pass + num_rows_per_pass),
mock_values.call_args_list,
)
self.assertListEqual(
(
[call(mock_values.return_value.return_defaults.return_value)]
* num_stories_per_pass
)
+ ([call(mock_values.return_value)] * num_rows_per_pass),
mock_dst_conn.execute.call_args_list,
)

@patch("sqlsynthgen.create.insert")
def test_populate_diff_length(self, mock_insert: MagicMock) -> None:
Expand Down Expand Up @@ -131,3 +149,42 @@ def test_create_db_vocab(
)
# Running the same insert twice should be fine.
create_db_vocab(vocab_list)


class TestStoryDefaults(RequiresDBTestCase):
"""Test that we can handle column defaults in stories."""

# pylint: disable=invalid-name
Base = declarative_base()
# pylint: enable=invalid-name
metadata = Base.metadata

class ColumnDefaultsTable(Base): # type: ignore
"""A SQLAlchemy model."""

__tablename__ = "column_defaults"
someval = Column(Integer, primary_key=True)
otherval = Column(Integer, server_default="8")

def setUp(self) -> None:
"""Ensure we have an empty DB to work with."""
dump_file_path = Path("dst.dump")
examples_dir = Path("tests/examples")
run_psql(examples_dir / dump_file_path)

def test_populate(self) -> None:
"""Check that we can populate a table that has column defaults."""
engine = create_engine(
"postgresql://postgres:password@localhost:5432/dst",
)
self.metadata.create_all(engine)

def my_story() -> Story:
"""A story generator."""
first_row = yield "column_defaults", {}
self.assertEqual(1, first_row["someval"])
self.assertEqual(8, first_row["otherval"])

with engine.connect() as conn:
with conn.begin():
_populate_story(my_story(), dict(self.metadata.tables), {}, conn)