Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass src_conn and dst_conn to generators when we create-data #26

Merged
merged 5 commits into from
Jan 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions sqlsynthgen/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,23 @@ def create_db_tables(metadata: Any) -> Any:
def create_db_data(sorted_tables: list, sorted_generators: list, num_rows: int) -> None:
"""Connect to a database and populate it with data."""
settings = get_settings()
engine = create_engine(settings.dst_postgres_dsn)
dst_engine = create_engine(settings.dst_postgres_dsn)
src_engine = create_engine(settings.src_postgres_dsn)

with engine.connect() as conn:
populate(conn, sorted_tables, sorted_generators, num_rows)
with dst_engine.connect() as dst_conn:
with src_engine.connect() as src_conn:
populate(src_conn, dst_conn, sorted_tables, sorted_generators, num_rows)


def populate(conn: Any, tables: list, generators: list, num_rows: int) -> None:
def populate(
src_conn: Any, dst_conn: Any, tables: list, generators: list, num_rows: int
) -> None:
"""Populate a database schema with dummy data."""

for table, generator in zip(tables, generators):
# Run all the inserts for one table in a transaction
with conn.begin():
for table, generator in zip(
tables, generators
): # Run all the inserts for one table in a transaction
with dst_conn.begin():
for _ in range(num_rows):
stmt = insert(table).values(generator(conn).__dict__)
conn.execute(stmt)
stmt = insert(table).values(generator(src_conn, dst_conn).__dict__)
dst_conn.execute(stmt)
12 changes: 6 additions & 6 deletions sqlsynthgen/make.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
'"""This file was auto-generated by sqlsynthgen but can be edited manually."""',
"from mimesis import Generic",
"from mimesis.locales import Locale",
"from sqlsynthgen.providers import BinaryProvider, ForeignKeyProvider",
"from sqlsynthgen.providers import BytesProvider, ColumnValueProvider",
"",
"generic = Generic(locale=Locale.EN)",
"generic.add_provider(ForeignKeyProvider)",
"generic.add_provider(BinaryProvider)",
"generic.add_provider(ColumnValueProvider)",
"generic.add_provider(BytesProvider)",
"",
)
)
Expand Down Expand Up @@ -42,7 +42,7 @@ def make_generators_from_tables(tables_module: ModuleType) -> str:
sqltypes.DateTime: "generic.datetime.datetime()",
sqltypes.Float: "generic.numeric.float_number()",
sqltypes.Integer: "generic.numeric.integer_number()",
sqltypes.LargeBinary: "generic.binary_provider.bytes()",
sqltypes.LargeBinary: "generic.bytes_provider.bytes()",
sqltypes.Numeric: "generic.numeric.float_number()",
sqltypes.String: "generic.text.color()",
sqltypes.Text: "generic.text.color()",
Expand All @@ -56,7 +56,7 @@ def make_generators_from_tables(tables_module: ModuleType) -> str:
+ new_class_name
+ ":\n"
+ INDENTATION
+ "def __init__(self, db_connection):\n"
+ "def __init__(self, src_db_conn, dst_db_conn):\n"
)

for column in table.columns:
Expand All @@ -72,7 +72,7 @@ def make_generators_from_tables(tables_module: ModuleType) -> str:
fk_schema, fk_table, fk_column = fk_column_path.split(".")
new_content += (
f"{INDENTATION*2}self.{column.name} = "
f"generic.foreign_key_provider.key(db_connection, "
f"generic.column_value_provider.column_value(dst_db_conn, "
f'"{fk_schema}", "{fk_table}", "{fk_column}"'
")\n"
)
Expand Down
24 changes: 11 additions & 13 deletions sqlsynthgen/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,33 @@

from mimesis import Text
from mimesis.providers.base import BaseDataProvider, BaseProvider

# from mimesis.locales import Locale
from sqlalchemy.sql import text

# generic = Generic(locale=Locale.EN)


class ForeignKeyProvider(BaseProvider):
"""A Mimesis provider of foreign keys."""
class ColumnValueProvider(BaseProvider):
"""A Mimesis provider of random values from the source database."""

class Meta:
"""Meta-class for ForeignKeyProvider settings."""
"""Meta-class for ColumnValueProvider settings."""

name = "foreign_key_provider"
name = "column_value_provider"

def key(self, db_connection: Any, schema: str, table: str, column: str) -> Any:
"""Return a random value from the table and column specified."""
def column_value(
self, db_connection: Any, schema: str, table: str, column: str
) -> Any:
"""Return a random value from the column specified."""
query_str = f"SELECT {column} FROM {schema}.{table} ORDER BY random() LIMIT 1"
key = db_connection.execute(text(query_str)).fetchone()[0]
return key


class BinaryProvider(BaseDataProvider):
class BytesProvider(BaseDataProvider):
"""A Mimesis provider of binary data."""

class Meta:
"""Meta-class for ForeignKeyProvider settings."""
"""Meta-class for BytesProvider settings."""

name = "binary_provider"
name = "bytes_provider"

def bytes(self) -> bytes:
"""Return a UTF-8 encoded sentence."""
Expand Down
16 changes: 8 additions & 8 deletions tests/examples/expected_ssg.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
"""This file was auto-generated by sqlsynthgen but can be edited manually."""
from mimesis import Generic
from mimesis.locales import Locale
from sqlsynthgen.providers import BinaryProvider, ForeignKeyProvider
from sqlsynthgen.providers import BytesProvider, ColumnValueProvider

generic = Generic(locale=Locale.EN)
generic.add_provider(ForeignKeyProvider)
generic.add_provider(BinaryProvider)
generic.add_provider(ColumnValueProvider)
generic.add_provider(BytesProvider)


class entityGenerator:
def __init__(self, db_connection):
def __init__(self, src_db_conn, dst_db_conn):
pass


class personGenerator:
def __init__(self, db_connection):
def __init__(self, src_db_conn, dst_db_conn):
pass
self.name = generic.text.color()
self.nhs_number = generic.text.color()
Expand All @@ -24,13 +24,13 @@ def __init__(self, db_connection):


class hospital_visitGenerator:
def __init__(self, db_connection):
def __init__(self, src_db_conn, dst_db_conn):
pass
self.person_id = generic.foreign_key_provider.key(db_connection, "myschema", "person", "person_id")
self.person_id = generic.column_value_provider.column_value(dst_db_conn, "myschema", "person", "person_id")
self.visit_start = generic.datetime.datetime()
self.visit_end = generic.datetime.date()
self.visit_duration_seconds = generic.numeric.float_number()
self.visit_image = generic.binary_provider.bytes()
self.visit_image = generic.bytes_provider.bytes()


sorted_generators = [
Expand Down
52 changes: 52 additions & 0 deletions tests/examples/providers.dump
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
--
-- PostgreSQL database dump
--

-- Dumped from database version 14.2 (Debian 14.2-1.pgdg110+1)
-- Dumped by pg_dump version 14.6 (Homebrew)

SET statement_timeout = 0;
SET lock_timeout = 0;
SET idle_in_transaction_session_timeout = 0;
SET client_encoding = 'UTF8';
SET standard_conforming_strings = on;
SELECT pg_catalog.set_config('search_path', '', false);
SET check_function_bodies = false;
SET xmloption = content;
SET client_min_messages = warning;
SET row_security = off;

DROP DATABASE IF EXISTS providers;
--
-- Name: providers; Type: DATABASE; Schema: -; Owner: postgres
--

CREATE DATABASE providers WITH TEMPLATE = template0 ENCODING = 'UTF8' LOCALE = 'en_US.utf8';


ALTER DATABASE providers OWNER TO postgres;

\connect providers

SET statement_timeout = 0;
SET lock_timeout = 0;
SET idle_in_transaction_session_timeout = 0;
SET client_encoding = 'UTF8';
SET standard_conforming_strings = on;
SELECT pg_catalog.set_config('search_path', '', false);
SET check_function_bodies = false;
SET xmloption = content;
SET client_min_messages = warning;
SET row_security = off;

SET default_tablespace = '';

SET default_table_access_method = heap;

--
-- Name: patient; Type: TABLE; Schema: public; Owner: postgres
--

CREATE TABLE public.patient (
sex text NOT NULL
);
22 changes: 20 additions & 2 deletions tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from unittest import TestCase
from unittest.mock import MagicMock, patch

from sqlsynthgen.create import create_db_data, create_db_tables
from sqlsynthgen.create import create_db_data, create_db_tables, populate
from tests.utils import get_test_settings


Expand All @@ -21,7 +21,7 @@ def test_create_db_data(self) -> None:
create_db_data([], [], 0)

mock_populate.assert_called_once()
mock_create_engine.assert_called_once()
mock_create_engine.assert_called()

def test_create_db_tables(self) -> None:
"""Test the create_tables function."""
Expand All @@ -36,3 +36,21 @@ def test_create_db_tables(self) -> None:
mock_create_engine.assert_called_once_with(
mock_get_settings.return_value.dst_postgres_dsn
)

def test_populate(self) -> None:
    """Check that populate calls the generator and executes one insert.

    With a single table, a single generator and ``num_rows=1`` the
    generator must be called once with both connections, its ``__dict__``
    must be passed to ``insert(...).values``, and the resulting statement
    must be executed on the destination connection.
    """
    src_conn = MagicMock()
    dst_conn = MagicMock()
    generator = MagicMock()

    with patch("sqlsynthgen.create.insert") as insert_mock:
        populate(src_conn, dst_conn, [None], [generator], 1)

    # Mocks keep their call records after the patch is undone.
    generator.assert_called_once_with(src_conn, dst_conn)

    values_mock = insert_mock.return_value.values
    values_mock.assert_called_once_with(generator.return_value.__dict__)
    dst_conn.execute.assert_called_once_with(values_mock.return_value)
26 changes: 4 additions & 22 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Tests for the main module."""
"""Tests for the CLI."""
import os
from pathlib import Path
from subprocess import run
from unittest import TestCase, skipUnless

from tests.utils import run_psql


@skipUnless(
os.environ.get("FUNCTIONAL_TESTS") == "1", "Set 'FUNCTIONAL_TESTS=1' to enable."
Expand All @@ -19,27 +21,7 @@ def setUp(self) -> None:
self.orm_file_path.unlink(missing_ok=True)
self.ssg_file_path.unlink(missing_ok=True)

# If you need to update src.dump or dst.dump, use
# pg_dump -d src|dst -h localhost -U postgres -C -c > tests/examples/src|dst.dump

env = os.environ.copy()
env = {**env, "PGPASSWORD": "password"}

# Clear and re-create the destination database
completed_process = run(
[
"psql",
"--host=localhost",
"--username=postgres",
"--file=" + str(Path("tests/examples/dst.dump")),
],
capture_output=True,
env=env,
check=True,
)

# psql doesn't always return != 0 if it fails
assert completed_process.stderr == b"", completed_process.stderr
run_psql("dst.dump")

def test_workflow(self) -> None:
"""Test the recommended CLI workflow runs without errors."""
Expand Down
64 changes: 64 additions & 0 deletions tests/test_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Tests for the providers module."""
import os
from unittest import TestCase, skipUnless

from sqlalchemy import Column, Integer, Text, create_engine, insert
from sqlalchemy.ext.declarative import declarative_base

from sqlsynthgen.providers import BytesProvider, ColumnValueProvider
from tests.utils import run_psql

# pylint: disable=invalid-name
Base = declarative_base()
# pylint: enable=invalid-name
metadata = Base.metadata


class Person(Base):  # type: ignore
    """SQLAlchemy model for the ``person`` table used by the tests below."""

    __tablename__ = "person"

    person_id = Column(Integer, primary_key=True)
    # A plain text column is sufficient here: no foreign key constraint is
    # required to exercise the provider.
    sex = Column(Text)


class BinaryProviderTestCase(TestCase):
    """Tests for the BytesProvider class."""

    def test_bytes(self) -> None:
        """Test that bytes() returns non-empty data that decodes as UTF-8."""
        decoded = BytesProvider().bytes().decode("utf-8")
        # assertNotEqual reports the offending value on failure, unlike
        # assertTrue on a pre-evaluated comparison.
        self.assertNotEqual("", decoded)


@skipUnless(
    os.environ.get("FUNCTIONAL_TESTS") == "1", "Set 'FUNCTIONAL_TESTS=1' to enable."
)
class ColumnValueProviderTestCase(TestCase):
    """Tests for the ColumnValueProvider class.

    These tests talk to a database and are therefore only run when the
    ``FUNCTIONAL_TESTS`` environment variable is set to ``1``.
    """

    def setUp(self) -> None:
        """Load the providers dump, connect to it and create the ORM tables."""

        run_psql("providers.dump")

        self.engine = create_engine(
            "postgresql://postgres:password@localhost:5432/providers"
        )
        # Release pooled connections after each test so that repeated runs
        # do not leave connections open against the test database.
        self.addCleanup(self.engine.dispose)
        metadata.create_all(self.engine)

    def test_column_value(self) -> None:
        """Test the column_value method."""
        # pylint: disable=invalid-name

        with self.engine.connect() as conn:
            stmt = insert(Person).values(sex="M")
            conn.execute(stmt)

            # Only one row exists, so the "random" value must be that row's.
            provider = ColumnValueProvider()
            value = provider.column_value(conn, "public", "person", "sex")

        self.assertEqual("M", value)
Loading