Skip to content

Commit

Permalink
Make src & dst database params optionals (#90)
Browse files Browse the repository at this point in the history
* Remove src_conn param from generators

* Make src and dst db params optional

We want to be able to run SSG without a src db connection
in case the developer and end user are different people.

* Require src db params for make- commands

* Fix unit tests
  • Loading branch information
Iain-S authored May 25, 2023
1 parent d098b1e commit 4ec2edb
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 96 deletions.
20 changes: 4 additions & 16 deletions sqlsynthgen/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,10 @@ def create_db_data(
if settings.dst_schema
else create_engine(settings.dst_postgres_dsn)
)
src_engine = (
create_engine_with_search_path(
settings.src_postgres_dsn, settings.src_schema # type: ignore
)
if settings.src_schema is not None
else create_engine(settings.src_postgres_dsn)
)

with dst_engine.connect() as dst_conn, src_engine.connect() as src_conn:
with dst_engine.connect() as dst_conn:
for _ in range(num_passes):
populate(
src_conn,
dst_conn,
sorted_tables,
table_generator_dict,
Expand All @@ -93,7 +85,6 @@ def _populate_story(
story: Story,
table_dict: Dict[str, Any],
table_generator_dict: Dict[str, Any],
src_conn: Any,
dst_conn: Any,
) -> None:
"""Write to the database all the rows created by the given story."""
Expand All @@ -107,7 +98,7 @@ def _populate_story(
table = table_dict[table_name]
if table.name in table_generator_dict:
table_generator = table_generator_dict[table.name]
default_values = table_generator(src_conn, dst_conn).__dict__
default_values = table_generator(dst_conn).__dict__
else:
default_values = {}
insert_values = {**default_values, **provided_values}
Expand All @@ -123,7 +114,6 @@ def _populate_story(


def populate(
src_conn: Any,
dst_conn: Any,
tables: list,
table_generator_dict: dict,
Expand All @@ -145,7 +135,7 @@ def populate(
for story in stories:
# Run the inserts for each story within a transaction.
with dst_conn.begin():
_populate_story(story, table_dict, table_generator_dict, src_conn, dst_conn)
_populate_story(story, table_dict, table_generator_dict, dst_conn)

# Generate individual rows, table by table.
for table in tables:
Expand All @@ -157,7 +147,5 @@ def populate(
# Run all the inserts for one table in a transaction
with dst_conn.begin():
for _ in range(table_generator.num_rows_per_pass):
stmt = insert(table).values(
table_generator(src_conn, dst_conn).__dict__
)
stmt = insert(table).values(table_generator(dst_conn).__dict__)
dst_conn.execute(stmt)
21 changes: 18 additions & 3 deletions sqlsynthgen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ def make_generators(
print(f"{ssg_file} should not already exist. Exiting...", file=stderr)
sys.exit(1)

settings = get_settings()
src_dsn = settings.src_postgres_dsn
if src_dsn is None:
print("Missing source database connection details.", file=stderr)
sys.exit(1)

orm_module: ModuleType = import_file(orm_file)
generator_config = read_yaml_file(config_file) if config_file is not None else {}
result: str = make_table_generators(
Expand All @@ -147,11 +153,15 @@ def make_stats(
if stats_file_path.exists() and not force:
print(f"{stats_file} should not already exist. Exiting...", file=stderr)
sys.exit(1)
settings = get_settings()

config = read_yaml_file(config_file) if config_file is not None else {}

settings = get_settings()
src_dsn = settings.src_postgres_dsn
if src_dsn is None:
raise ValueError("Missing source database connection details.")
print("Missing source database connection details.", file=stderr)
sys.exit(1)

src_stats = make_src_stats(src_dsn, config)
stats_file_path.write_text(yaml.dump(src_stats), encoding="utf-8")

Expand Down Expand Up @@ -180,8 +190,13 @@ def make_tables(
sys.exit(1)

settings = get_settings()
if settings.src_postgres_dsn is None:
print("Missing source database connection details.", file=stderr)
sys.exit(1)

src_dsn = settings.src_postgres_dsn

content = make_tables_file(settings.src_postgres_dsn, settings.src_schema) # type: ignore
content = make_tables_file(src_dsn, settings.src_schema)
orm_file_path.write_text(content, encoding="utf-8")


Expand Down
72 changes: 29 additions & 43 deletions sqlsynthgen/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,26 @@ class Settings(BaseSettings):
Connection database e.g. `postgres`
dst_ssl_required (bool) :
Flag `True` if db requires SSL
"""

# Connection parameters for the source PostgreSQL database. See also
# https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS
src_host_name: str # e.g. "mydb.mydomain.com" or "0.0.0.0"
src_host_name: Optional[str] # e.g. "mydb.mydomain.com" or "0.0.0.0"
src_port: int = 5432
src_user_name: str # e.g. "postgres" or "myuser@mydb"
src_password: str
src_db_name: str
src_user_name: Optional[str] # e.g. "postgres" or "myuser@mydb"
src_password: Optional[str]
src_db_name: Optional[str]
src_ssl_required: bool = False # whether the db requires SSL
src_schema: Optional[str]

# Connection parameters for the destination PostgreSQL database.
dst_host_name: str # Connection parameter e.g. "mydb.mydomain.com" or "0.0.0.0"
dst_host_name: Optional[
str
] # Connection parameter e.g. "mydb.mydomain.com" or "0.0.0.0"
dst_port: int = 5432
dst_user_name: str # e.g. "postgres" or "myuser@mydb"
dst_password: str
dst_db_name: str
dst_user_name: Optional[str] # e.g. "postgres" or "myuser@mydb"
dst_password: Optional[str]
dst_db_name: Optional[str]
dst_schema: Optional[str]
dst_ssl_required: bool = False # whether the db requires SSL

Expand All @@ -83,42 +84,24 @@ class Settings(BaseSettings):
dst_postgres_dsn: Optional[PostgresDsn]

@validator("src_postgres_dsn", pre=True)
def validate_src_postgres_dsn(cls, _: Optional[PostgresDsn], values: Any) -> str:
"""Create and validate the source db data source name.
Args:
cls (Settings): Self Settings instance
values (Optional Any): Eg. parameters for source database connection
Return:
(str): Validated data source name
"""
def validate_src_postgres_dsn(
cls, _: Optional[PostgresDsn], values: Any
) -> Optional[str]:
"""Create and validate the source db data source name."""
return cls.check_postgres_dsn(_, values, "src")

@validator("dst_postgres_dsn", pre=True)
def validate_dst_postgres_dsn(cls, _: Optional[PostgresDsn], values: Any) -> str:
"""Create and validate the destination db data source name.
Args:
cls (Settings): Self Settings instance
values (Optional Any): Eg. Connection parameters for destination database
Return:
(str): Validated data source name
"""
def validate_dst_postgres_dsn(
cls, _: Optional[PostgresDsn], values: Any
) -> Optional[str]:
"""Create and validate the destination db data source name."""
return cls.check_postgres_dsn(_, values, "dst")

@staticmethod
def check_postgres_dsn(_: Optional[PostgresDsn], values: Any, prefix: str) -> str:
"""Build a DSN string from the host, db name, port, username and password.
Args:
cls (Settings): Self Settings instance
values (Optional Any): Eg. Connection parameters
Return:
(str): A data source name
"""
def check_postgres_dsn(
_: Optional[PostgresDsn], values: Any, prefix: str
) -> Optional[str]:
"""Build a DSN string from the host, db name, port, username and password."""
# We want to build the Data Source Name ourselves so none should be provided
if _:
raise ValueError("postgres_dsn should not be provided")
Expand All @@ -129,12 +112,15 @@ def check_postgres_dsn(_: Optional[PostgresDsn], values: Any, prefix: str) -> st
port = values[f"{prefix}_port"]
db_name = values[f"{prefix}_db_name"]

dsn = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
if user and password and host and port and db_name:
dsn = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"

if values[f"{prefix}_ssl_required"]:
return dsn + "?sslmode=require"

if values[f"{prefix}_ssl_required"]:
return dsn + "?sslmode=require"
return dsn

return dsn
return None

@dataclass
class Config:
Expand Down
2 changes: 1 addition & 1 deletion sqlsynthgen/templates/ssg.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ with open("{{ src_stats_filename }}", "r", encoding="utf-8") as f:
class {{ table_data.class_name }}:
num_rows_per_pass = {{ table_data.rows_per_pass }}

def __init__(self, src_db_conn, dst_db_conn):
def __init__(self, dst_db_conn):
{% for column_data in table_data.columns %}
{% if column_data.primary_key %}
pass
Expand Down
8 changes: 4 additions & 4 deletions tests/examples/expected_ssg.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@
class entityGenerator:
num_rows_per_pass = 1

def __init__(self, src_db_conn, dst_db_conn):
def __init__(self, dst_db_conn):
pass


class personGenerator:
num_rows_per_pass = 2

def __init__(self, src_db_conn, dst_db_conn):
def __init__(self, dst_db_conn):
self.name = generic.person.full_name()
self.stored_from = generic.datetime.datetime(start=2022, end=2022)
self.research_opt_out = row_generators.boolean_from_src_stats_generator(
Expand All @@ -57,15 +57,15 @@ def __init__(self, src_db_conn, dst_db_conn):
class test_entityGenerator:
num_rows_per_pass = 1

def __init__(self, src_db_conn, dst_db_conn):
def __init__(self, dst_db_conn):
pass
self.single_letter_column = generic.person.password(1)


class hospital_visitGenerator:
num_rows_per_pass = 3

def __init__(self, src_db_conn, dst_db_conn):
def __init__(self, dst_db_conn):
(
self.visit_start,
self.visit_end,
Expand Down
10 changes: 3 additions & 7 deletions tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
return story()

for num_stories_per_pass, num_rows_per_pass in itt.product([0, 2], [0, 3]):
mock_src_conn = MagicMock()
mock_dst_conn = MagicMock()
mock_dst_conn.execute.return_value.returned_defaults = {}
mock_table = MagicMock()
Expand All @@ -78,13 +77,10 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
if num_stories_per_pass > 0
else []
)
populate(
mock_src_conn, mock_dst_conn, tables, row_generators, story_generators
)
populate(mock_dst_conn, tables, row_generators, story_generators)

mock_gen.assert_has_calls(
[call(mock_src_conn, mock_dst_conn)]
* (num_stories_per_pass + num_rows_per_pass)
[call(mock_dst_conn)] * (num_stories_per_pass + num_rows_per_pass)
)
mock_insert.return_value.values.assert_has_calls(
[call(mock_gen.return_value.__dict__)]
Expand All @@ -110,7 +106,7 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None:
tables = [mock_table_one, mock_table_two, mock_table_three]
row_generators = {"two": mock_gen_two, "three": mock_gen_three}

populate(2, mock_dst_conn, tables, row_generators, [])
populate(mock_dst_conn, tables, row_generators, [])
self.assertListEqual(
[call(mock_table_two), call(mock_table_three)], mock_insert.call_args_list
)
Expand Down
Loading

0 comments on commit 4ec2edb

Please sign in to comment.