Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make src & dst database params optionals #90

Merged
merged 4 commits into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 4 additions & 16 deletions sqlsynthgen/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,10 @@ def create_db_data(
if settings.dst_schema
else create_engine(settings.dst_postgres_dsn)
)
src_engine = (
create_engine_with_search_path(
settings.src_postgres_dsn, settings.src_schema # type: ignore
)
if settings.src_schema is not None
else create_engine(settings.src_postgres_dsn)
)

with dst_engine.connect() as dst_conn, src_engine.connect() as src_conn:
with dst_engine.connect() as dst_conn:
for _ in range(num_passes):
populate(
src_conn,
dst_conn,
sorted_tables,
table_generator_dict,
Expand All @@ -93,7 +85,6 @@ def _populate_story(
story: Story,
table_dict: Dict[str, Any],
table_generator_dict: Dict[str, Any],
src_conn: Any,
dst_conn: Any,
) -> None:
"""Write to the database all the rows created by the given story."""
Expand All @@ -107,7 +98,7 @@ def _populate_story(
table = table_dict[table_name]
if table.name in table_generator_dict:
table_generator = table_generator_dict[table.name]
default_values = table_generator(src_conn, dst_conn).__dict__
default_values = table_generator(dst_conn).__dict__
else:
default_values = {}
insert_values = {**default_values, **provided_values}
Expand All @@ -123,7 +114,6 @@ def _populate_story(


def populate(
src_conn: Any,
dst_conn: Any,
tables: list,
table_generator_dict: dict,
Expand All @@ -145,7 +135,7 @@ def populate(
for story in stories:
# Run the inserts for each story within a transaction.
with dst_conn.begin():
_populate_story(story, table_dict, table_generator_dict, src_conn, dst_conn)
_populate_story(story, table_dict, table_generator_dict, dst_conn)

# Generate individual rows, table by table.
for table in tables:
Expand All @@ -157,7 +147,5 @@ def populate(
# Run all the inserts for one table in a transaction
with dst_conn.begin():
for _ in range(table_generator.num_rows_per_pass):
stmt = insert(table).values(
table_generator(src_conn, dst_conn).__dict__
)
stmt = insert(table).values(table_generator(dst_conn).__dict__)
dst_conn.execute(stmt)
21 changes: 18 additions & 3 deletions sqlsynthgen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ def make_generators(
print(f"{ssg_file} should not already exist. Exiting...", file=stderr)
sys.exit(1)

settings = get_settings()
src_dsn = settings.src_postgres_dsn
if src_dsn is None:
print("Missing source database connection details.", file=stderr)
sys.exit(1)

orm_module: ModuleType = import_file(orm_file)
generator_config = read_yaml_file(config_file) if config_file is not None else {}
result: str = make_table_generators(
Expand All @@ -147,11 +153,15 @@ def make_stats(
if stats_file_path.exists() and not force:
print(f"{stats_file} should not already exist. Exiting...", file=stderr)
sys.exit(1)
settings = get_settings()

config = read_yaml_file(config_file) if config_file is not None else {}

settings = get_settings()
src_dsn = settings.src_postgres_dsn
if src_dsn is None:
raise ValueError("Missing source database connection details.")
print("Missing source database connection details.", file=stderr)
sys.exit(1)

src_stats = make_src_stats(src_dsn, config)
stats_file_path.write_text(yaml.dump(src_stats), encoding="utf-8")

Expand Down Expand Up @@ -180,8 +190,13 @@ def make_tables(
sys.exit(1)

settings = get_settings()
if settings.src_postgres_dsn is None:
print("Missing source database connection details.", file=stderr)
sys.exit(1)

src_dsn = settings.src_postgres_dsn

content = make_tables_file(settings.src_postgres_dsn, settings.src_schema) # type: ignore
content = make_tables_file(src_dsn, settings.src_schema)
orm_file_path.write_text(content, encoding="utf-8")


Expand Down
72 changes: 29 additions & 43 deletions sqlsynthgen/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,26 @@ class Settings(BaseSettings):
Connection database e.g. `postgres`
dst_ssl_required (bool) :
Flag `True` if db requires SSL

"""

# Connection parameters for the source PostgreSQL database. See also
# https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS
src_host_name: str # e.g. "mydb.mydomain.com" or "0.0.0.0"
src_host_name: Optional[str] # e.g. "mydb.mydomain.com" or "0.0.0.0"
src_port: int = 5432
src_user_name: str # e.g. "postgres" or "myuser@mydb"
src_password: str
src_db_name: str
src_user_name: Optional[str] # e.g. "postgres" or "myuser@mydb"
src_password: Optional[str]
src_db_name: Optional[str]
src_ssl_required: bool = False # whether the db requires SSL
src_schema: Optional[str]

# Connection parameters for the destination PostgreSQL database.
dst_host_name: str # Connection parameter e.g. "mydb.mydomain.com" or "0.0.0.0"
dst_host_name: Optional[
str
] # Connection parameter e.g. "mydb.mydomain.com" or "0.0.0.0"
dst_port: int = 5432
dst_user_name: str # e.g. "postgres" or "myuser@mydb"
dst_password: str
dst_db_name: str
dst_user_name: Optional[str] # e.g. "postgres" or "myuser@mydb"
dst_password: Optional[str]
dst_db_name: Optional[str]
dst_schema: Optional[str]
dst_ssl_required: bool = False # whether the db requires SSL

Expand All @@ -83,42 +84,24 @@ class Settings(BaseSettings):
dst_postgres_dsn: Optional[PostgresDsn]

@validator("src_postgres_dsn", pre=True)
def validate_src_postgres_dsn(cls, _: Optional[PostgresDsn], values: Any) -> str:
"""Create and validate the source db data source name.

Args:
cls (Settings): Self Settings instance
values (Optional Any): Eg. parameters for source database connection

Return:
(str): Validated data source name
"""
def validate_src_postgres_dsn(
cls, _: Optional[PostgresDsn], values: Any
) -> Optional[str]:
"""Create and validate the source db data source name."""
return cls.check_postgres_dsn(_, values, "src")

@validator("dst_postgres_dsn", pre=True)
def validate_dst_postgres_dsn(cls, _: Optional[PostgresDsn], values: Any) -> str:
"""Create and validate the destination db data source name.

Args:
cls (Settings): Self Settings instance
values (Optional Any): Eg. Connection parameters for destination database

Return:
(str): Validated data source name
"""
def validate_dst_postgres_dsn(
cls, _: Optional[PostgresDsn], values: Any
) -> Optional[str]:
"""Create and validate the destination db data source name."""
return cls.check_postgres_dsn(_, values, "dst")

@staticmethod
def check_postgres_dsn(_: Optional[PostgresDsn], values: Any, prefix: str) -> str:
"""Build a DSN string from the host, db name, port, username and password.

Args:
cls (Settings): Self Settings instance
values (Optional Any): Eg. Connection parameters

Return:
(str): A data source name
"""
def check_postgres_dsn(
_: Optional[PostgresDsn], values: Any, prefix: str
) -> Optional[str]:
"""Build a DSN string from the host, db name, port, username and password."""
# We want to build the Data Source Name ourselves so none should be provided
if _:
raise ValueError("postgres_dsn should not be provided")
Expand All @@ -129,12 +112,15 @@ def check_postgres_dsn(_: Optional[PostgresDsn], values: Any, prefix: str) -> st
port = values[f"{prefix}_port"]
db_name = values[f"{prefix}_db_name"]

dsn = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
if user and password and host and port and db_name:
dsn = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"

if values[f"{prefix}_ssl_required"]:
return dsn + "?sslmode=require"

if values[f"{prefix}_ssl_required"]:
return dsn + "?sslmode=require"
return dsn

return dsn
return None

@dataclass
class Config:
Expand Down
2 changes: 1 addition & 1 deletion sqlsynthgen/templates/ssg.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ with open("{{ src_stats_filename }}", "r", encoding="utf-8") as f:
class {{ table_data.class_name }}:
num_rows_per_pass = {{ table_data.rows_per_pass }}

def __init__(self, src_db_conn, dst_db_conn):
def __init__(self, dst_db_conn):
{% for column_data in table_data.columns %}
{% if column_data.primary_key %}
pass
Expand Down
8 changes: 4 additions & 4 deletions tests/examples/expected_ssg.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@
class entityGenerator:
num_rows_per_pass = 1

def __init__(self, src_db_conn, dst_db_conn):
def __init__(self, dst_db_conn):
pass


class personGenerator:
num_rows_per_pass = 2

def __init__(self, src_db_conn, dst_db_conn):
def __init__(self, dst_db_conn):
self.name = generic.person.full_name()
self.stored_from = generic.datetime.datetime(start=2022, end=2022)
self.research_opt_out = row_generators.boolean_from_src_stats_generator(
Expand All @@ -57,15 +57,15 @@ def __init__(self, src_db_conn, dst_db_conn):
class test_entityGenerator:
num_rows_per_pass = 1

def __init__(self, src_db_conn, dst_db_conn):
def __init__(self, dst_db_conn):
pass
self.single_letter_column = generic.person.password(1)


class hospital_visitGenerator:
num_rows_per_pass = 3

def __init__(self, src_db_conn, dst_db_conn):
def __init__(self, dst_db_conn):
(
self.visit_start,
self.visit_end,
Expand Down
10 changes: 3 additions & 7 deletions tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
return story()

for num_stories_per_pass, num_rows_per_pass in itt.product([0, 2], [0, 3]):
mock_src_conn = MagicMock()
mock_dst_conn = MagicMock()
mock_dst_conn.execute.return_value.returned_defaults = {}
mock_table = MagicMock()
Expand All @@ -78,13 +77,10 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
if num_stories_per_pass > 0
else []
)
populate(
mock_src_conn, mock_dst_conn, tables, row_generators, story_generators
)
populate(mock_dst_conn, tables, row_generators, story_generators)

mock_gen.assert_has_calls(
[call(mock_src_conn, mock_dst_conn)]
* (num_stories_per_pass + num_rows_per_pass)
[call(mock_dst_conn)] * (num_stories_per_pass + num_rows_per_pass)
)
mock_insert.return_value.values.assert_has_calls(
[call(mock_gen.return_value.__dict__)]
Expand All @@ -110,7 +106,7 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None:
tables = [mock_table_one, mock_table_two, mock_table_three]
row_generators = {"two": mock_gen_two, "three": mock_gen_three}

populate(2, mock_dst_conn, tables, row_generators, [])
populate(mock_dst_conn, tables, row_generators, [])
self.assertListEqual(
[call(mock_table_two), call(mock_table_three)], mock_insert.call_args_list
)
Expand Down
Loading