diff --git a/sqlsynthgen/create.py b/sqlsynthgen/create.py index 9413eb2..00fa11b 100644 --- a/sqlsynthgen/create.py +++ b/sqlsynthgen/create.py @@ -70,18 +70,10 @@ def create_db_data( if settings.dst_schema else create_engine(settings.dst_postgres_dsn) ) - src_engine = ( - create_engine_with_search_path( - settings.src_postgres_dsn, settings.src_schema # type: ignore - ) - if settings.src_schema is not None - else create_engine(settings.src_postgres_dsn) - ) - with dst_engine.connect() as dst_conn, src_engine.connect() as src_conn: + with dst_engine.connect() as dst_conn: for _ in range(num_passes): populate( - src_conn, dst_conn, sorted_tables, table_generator_dict, @@ -93,7 +85,6 @@ def _populate_story( story: Story, table_dict: Dict[str, Any], table_generator_dict: Dict[str, Any], - src_conn: Any, dst_conn: Any, ) -> None: """Write to the database all the rows created by the given story.""" @@ -107,7 +98,7 @@ def _populate_story( table = table_dict[table_name] if table.name in table_generator_dict: table_generator = table_generator_dict[table.name] - default_values = table_generator(src_conn, dst_conn).__dict__ + default_values = table_generator(dst_conn).__dict__ else: default_values = {} insert_values = {**default_values, **provided_values} @@ -123,7 +114,6 @@ def _populate_story( def populate( - src_conn: Any, dst_conn: Any, tables: list, table_generator_dict: dict, @@ -145,7 +135,7 @@ def populate( for story in stories: # Run the inserts for each story within a transaction. with dst_conn.begin(): - _populate_story(story, table_dict, table_generator_dict, src_conn, dst_conn) + _populate_story(story, table_dict, table_generator_dict, dst_conn) # Generate individual rows, table by table. for table in tables: @@ -157,7 +147,5 @@ def populate( # Run all the inserts for one table in a transaction with dst_conn.begin(): for _ in range(table_generator.num_rows_per_pass): - stmt = insert(table).values( - table_generator(src_conn, dst_conn).__dict__ - ) + stmt = insert(table).values(table_generator(dst_conn).__dict__) dst_conn.execute(stmt) diff --git a/sqlsynthgen/main.py b/sqlsynthgen/main.py index 13427bb..28ec9e8 100644 --- a/sqlsynthgen/main.py +++ b/sqlsynthgen/main.py @@ -123,6 +123,12 @@ def make_generators( print(f"{ssg_file} should not already exist. Exiting...", file=stderr) sys.exit(1) + settings = get_settings() + src_dsn = settings.src_postgres_dsn + if src_dsn is None: + print("Missing source database connection details.", file=stderr) + sys.exit(1) + orm_module: ModuleType = import_file(orm_file) generator_config = read_yaml_file(config_file) if config_file is not None else {} result: str = make_table_generators( @@ -147,11 +153,15 @@ def make_stats( if stats_file_path.exists() and not force: print(f"{stats_file} should not already exist. Exiting...", file=stderr) sys.exit(1) - settings = get_settings() + config = read_yaml_file(config_file) if config_file is not None else {} + + settings = get_settings() src_dsn = settings.src_postgres_dsn if src_dsn is None: - raise ValueError("Missing source database connection details.") + print("Missing source database connection details.", file=stderr) + sys.exit(1) + src_stats = make_src_stats(src_dsn, config) stats_file_path.write_text(yaml.dump(src_stats), encoding="utf-8") @@ -180,8 +190,13 @@ def make_tables( sys.exit(1) settings = get_settings() + if settings.src_postgres_dsn is None: + print("Missing source database connection details.", file=stderr) + sys.exit(1) + + src_dsn = settings.src_postgres_dsn - content = make_tables_file(settings.src_postgres_dsn, settings.src_schema) # type: ignore + content = make_tables_file(src_dsn, settings.src_schema) orm_file_path.write_text(content, encoding="utf-8") diff --git a/sqlsynthgen/settings.py b/sqlsynthgen/settings.py index 6ddb0cb..2eae098 100644 --- a/sqlsynthgen/settings.py +++ b/sqlsynthgen/settings.py @@ -56,25 +56,26 @@ class Settings(BaseSettings): Connection database e.g. `postgres` dst_ssl_required (bool) : Flag `True` if db requires SSL - """ # Connection parameters for the source PostgreSQL database. See also # https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS - src_host_name: str # e.g. "mydb.mydomain.com" or "0.0.0.0" + src_host_name: Optional[str] # e.g. "mydb.mydomain.com" or "0.0.0.0" src_port: int = 5432 - src_user_name: str # e.g. "postgres" or "myuser@mydb" - src_password: str - src_db_name: str + src_user_name: Optional[str] # e.g. "postgres" or "myuser@mydb" + src_password: Optional[str] + src_db_name: Optional[str] src_ssl_required: bool = False # whether the db requires SSL src_schema: Optional[str] # Connection parameters for the destination PostgreSQL database. - dst_host_name: str # Connection parameter e.g. "mydb.mydomain.com" or "0.0.0.0" + dst_host_name: Optional[ + str + ] # Connection parameter e.g. "mydb.mydomain.com" or "0.0.0.0" dst_port: int = 5432 - dst_user_name: str # e.g. "postgres" or "myuser@mydb" - dst_password: str - dst_db_name: str + dst_user_name: Optional[str] # e.g. "postgres" or "myuser@mydb" + dst_password: Optional[str] + dst_db_name: Optional[str] dst_schema: Optional[str] dst_ssl_required: bool = False # whether the db requires SSL @@ -83,42 +84,24 @@ class Settings(BaseSettings): dst_postgres_dsn: Optional[PostgresDsn] @validator("src_postgres_dsn", pre=True) - def validate_src_postgres_dsn(cls, _: Optional[PostgresDsn], values: Any) -> str: - """Create and validate the source db data source name. - - Args: - cls (Settings): Self Settings instance - values (Optional Any): Eg. parameters for source database connection - - Return: - (str): Validated data source name - """ + def validate_src_postgres_dsn( + cls, _: Optional[PostgresDsn], values: Any + ) -> Optional[str]: + """Create and validate the source db data source name.""" return cls.check_postgres_dsn(_, values, "src") @validator("dst_postgres_dsn", pre=True) - def validate_dst_postgres_dsn(cls, _: Optional[PostgresDsn], values: Any) -> str: - """Create and validate the destination db data source name. - - Args: - cls (Settings): Self Settings instance - values (Optional Any): Eg. Connection parameters for destination database - - Return: - (str): Validated data source name - """ + def validate_dst_postgres_dsn( + cls, _: Optional[PostgresDsn], values: Any + ) -> Optional[str]: + """Create and validate the destination db data source name.""" return cls.check_postgres_dsn(_, values, "dst") @staticmethod - def check_postgres_dsn(_: Optional[PostgresDsn], values: Any, prefix: str) -> str: - """Build a DSN string from the host, db name, port, username and password. - - Args: - cls (Settings): Self Settings instance - values (Optional Any): Eg. Connection parameters - - Return: - (str): A data source name - """ + def check_postgres_dsn( + _: Optional[PostgresDsn], values: Any, prefix: str + ) -> Optional[str]: + """Build a DSN string from the host, db name, port, username and password.""" # We want to build the Data Source Name ourselves so none should be provided if _: raise ValueError("postgres_dsn should not be provided") @@ -129,12 +112,15 @@ def check_postgres_dsn(_: Optional[PostgresDsn], values: Any, prefix: str) -> st port = values[f"{prefix}_port"] db_name = values[f"{prefix}_db_name"] - dsn = f"postgresql://{user}:{password}@{host}:{port}/{db_name}" + if user and password and host and port and db_name: + dsn = f"postgresql://{user}:{password}@{host}:{port}/{db_name}" + + if values[f"{prefix}_ssl_required"]: + return dsn + "?sslmode=require" - if values[f"{prefix}_ssl_required"]: - return dsn + "?sslmode=require" + return dsn - return dsn + return None @dataclass class Config: diff --git a/sqlsynthgen/templates/ssg.py.j2 b/sqlsynthgen/templates/ssg.py.j2 index 44c9830..244e994 100644 --- a/sqlsynthgen/templates/ssg.py.j2 +++ b/sqlsynthgen/templates/ssg.py.j2 @@ -32,7 +32,7 @@ with open("{{ src_stats_filename }}", "r", encoding="utf-8") as f: class {{ table_data.class_name }}: num_rows_per_pass = {{ table_data.rows_per_pass }} - def __init__(self, src_db_conn, dst_db_conn): + def __init__(self, dst_db_conn): {% for column_data in table_data.columns %} {% if column_data.primary_key %} pass diff --git a/tests/examples/expected_ssg.py b/tests/examples/expected_ssg.py index 0ffad68..dd74e6f 100644 --- a/tests/examples/expected_ssg.py +++ b/tests/examples/expected_ssg.py @@ -36,14 +36,14 @@ class entityGenerator: num_rows_per_pass = 1 - def __init__(self, src_db_conn, dst_db_conn): + def __init__(self, dst_db_conn): pass class personGenerator: num_rows_per_pass = 2 - def __init__(self, src_db_conn, dst_db_conn): + def __init__(self, dst_db_conn): self.name = generic.person.full_name() self.stored_from = generic.datetime.datetime(start=2022, end=2022) self.research_opt_out = row_generators.boolean_from_src_stats_generator( @@ -57,7 +57,7 @@ def __init__(self, src_db_conn, dst_db_conn): class test_entityGenerator: num_rows_per_pass = 1 - def __init__(self, src_db_conn, dst_db_conn): + def __init__(self, dst_db_conn): pass self.single_letter_column = generic.person.password(1) @@ -65,7 +65,7 @@ def __init__(self, src_db_conn, dst_db_conn): class hospital_visitGenerator: num_rows_per_pass = 3 - def __init__(self, src_db_conn, dst_db_conn): + def __init__(self, dst_db_conn): ( self.visit_start, self.visit_end, diff --git a/tests/test_create.py b/tests/test_create.py index d1fde47..1a5b8ed 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -62,7 +62,6 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]: return story() for num_stories_per_pass, num_rows_per_pass in itt.product([0, 2], [0, 3]): - mock_src_conn = MagicMock() mock_dst_conn = MagicMock() mock_dst_conn.execute.return_value.returned_defaults = {} mock_table = MagicMock() @@ -78,13 +77,10 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]: if num_stories_per_pass > 0 else [] ) - populate( - mock_src_conn, mock_dst_conn, tables, row_generators, story_generators - ) + populate(mock_dst_conn, tables, row_generators, story_generators) mock_gen.assert_has_calls( - [call(mock_src_conn, mock_dst_conn)] - * (num_stories_per_pass + num_rows_per_pass) + [call(mock_dst_conn)] * (num_stories_per_pass + num_rows_per_pass) ) mock_insert.return_value.values.assert_has_calls( [call(mock_gen.return_value.__dict__)] @@ -110,7 +106,7 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None: tables = [mock_table_one, mock_table_two, mock_table_three] row_generators = {"two": mock_gen_two, "three": mock_gen_three} - populate(2, mock_dst_conn, tables, row_generators, []) + populate(mock_dst_conn, tables, row_generators, []) self.assertListEqual( [call(mock_table_two), call(mock_table_three)], mock_insert.call_args_list ) diff --git a/tests/test_main.py b/tests/test_main.py index 732f07a..4218c19 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -33,15 +33,21 @@ def test_create_vocab(self, mock_create: MagicMock, mock_import: MagicMock) -> N mock_create.assert_called_once_with(mock_import.return_value.vocab_dict) self.assertSuccess(result) + @patch("sqlsynthgen.main.get_settings") @patch("sqlsynthgen.main.import_file") @patch("sqlsynthgen.main.Path") @patch("sqlsynthgen.main.make_table_generators") def test_make_generators( - self, mock_make: MagicMock, mock_path: MagicMock, mock_import: MagicMock + self, + mock_make: MagicMock, + mock_path: MagicMock, + mock_import: MagicMock, + mock_settings: MagicMock, ) -> None: """Test the make-generators sub-command.""" mock_path.return_value.exists.return_value = False mock_make.return_value = "some text" + mock_settings.return_value.src_postges_dsn = "" result = runner.invoke( app, @@ -80,16 +86,39 @@ def test_make_generators_errors_if_file_exists( ) self.assertEqual(1, result.exit_code) + @patch("sqlsynthgen.main.stderr", new_callable=StringIO) + def test_make_generators_errors_if_src_dsn_missing( + self, mock_stderr: MagicMock + ) -> None: + """Test the make-generators sub-command with missing db params.""" + result = runner.invoke( + app, + [ + "make-generators", + ], + catch_exceptions=False, + ) + self.assertEqual( + "Missing source database connection details.\n", mock_stderr.getvalue() + ) + self.assertEqual(1, result.exit_code) + + @patch("sqlsynthgen.main.get_settings") @patch("sqlsynthgen.main.Path") @patch("sqlsynthgen.main.import_file") @patch("sqlsynthgen.main.make_table_generators") def test_make_generators_with_force_enabled( - self, mock_make: MagicMock, mock_import: MagicMock, mock_path: MagicMock + self, + mock_make: MagicMock, + mock_import: MagicMock, + mock_path: MagicMock, + mock_settings: MagicMock, ) -> None: """Tests the make-generators sub-commands overwrite files when instructed.""" mock_path.return_value.exists.return_value = True mock_make.return_value = "make result" + mock_settings.return_value.src_postges_dsn = "" for force_option in ["--force", "-f"]: with self.subTest(f"Using option {force_option}"): @@ -204,6 +233,24 @@ def test_make_tables_errors_if_file_exists( ) self.assertEqual(1, result.exit_code) + @patch("sqlsynthgen.main.stderr", new_callable=StringIO) + def test_make_tables_errors_if_src_dsn_missing( + self, mock_stderr: MagicMock + ) -> None: + """Test the make-tables sub-command doesn't overwrite.""" + + result = runner.invoke( + app, + [ + "make-tables", + ], + catch_exceptions=False, + ) + self.assertEqual( + "Missing source database connection details.\n", mock_stderr.getvalue() + ) + self.assertEqual(1, result.exit_code) + @patch("sqlsynthgen.main.make_tables_file") @patch("sqlsynthgen.main.get_settings") @patch("sqlsynthgen.main.Path") @@ -292,6 +339,28 @@ def test_make_stats_errors_if_file_exists( ) self.assertEqual(1, result.exit_code) + @patch("sqlsynthgen.main.stderr", new_callable=StringIO) + def test_make_stats_errors_if_no_src_dsn( + self, + mock_stderr: MagicMock, + ) -> None: + """Test the make-stats sub-command with missing settings.""" + example_conf_path = "tests/examples/example_config.yaml" + + result = runner.invoke( + app, + [ + "make-stats", + f"--config-file={example_conf_path}", + ], + catch_exceptions=False, + ) + self.assertEqual( + "Missing source database connection details.\n", + mock_stderr.getvalue(), + ) + self.assertEqual(1, result.exit_code) + @patch("sqlsynthgen.main.Path") @patch("sqlsynthgen.main.make_src_stats") @patch("sqlsynthgen.main.get_settings") diff --git a/tests/test_settings.py b/tests/test_settings.py index e90b3ca..a3024c5 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -6,59 +6,73 @@ class TestSettings(SSGTestCase): """Tests for the Settings class.""" - def test_default_settings(self) -> None: + def test_minimal_settings(self) -> None: """Test the minimal settings.""" + settings = Settings( + # To stop any local .env files influencing the test + _env_file=None, + ) + self.assertIsNone(settings.src_postgres_dsn) + self.assertEqual(5432, settings.src_port) + self.assertEqual(False, settings.src_ssl_required) + + self.assertIsNone(settings.dst_postgres_dsn) + self.assertEqual(5432, settings.dst_port) + self.assertEqual(False, settings.dst_ssl_required) + + def test_maximal_settings(self) -> None: + """Test the full settings.""" settings = Settings( src_host_name="shost", + src_port=1234, src_user_name="suser", src_password="spassword", src_db_name="sdbname", + src_ssl_required=True, dst_host_name="dhost", + dst_port=4321, dst_user_name="duser", dst_password="dpassword", dst_db_name="ddbname", + dst_schema="dschema", + dst_ssl_required=True, # To stop any local .env files influencing the test _env_file=None, ) self.assertEqual( - "postgresql://suser:spassword@shost:5432/sdbname", + "postgresql://suser:spassword@shost:1234/sdbname?sslmode=require", str(settings.src_postgres_dsn), ) - self.assertIsNone(settings.src_schema) - self.assertIsNone(settings.dst_schema) self.assertEqual( - "postgresql://duser:dpassword@dhost:5432/ddbname", + "postgresql://duser:dpassword@dhost:4321/ddbname?sslmode=require", str(settings.dst_postgres_dsn), ) - def test_maximal_settings(self) -> None: - """Test the full settings.""" + def test_typical_settings(self) -> None: + """Test that we can make src and dst Postgres DSNs.""" settings = Settings( src_host_name="shost", - src_port=1234, src_user_name="suser", src_password="spassword", src_db_name="sdbname", - src_ssl_required=True, dst_host_name="dhost", - dst_port=4321, dst_user_name="duser", dst_password="dpassword", dst_db_name="ddbname", - dst_schema="dschema", - dst_ssl_required=True, # To stop any local .env files influencing the test _env_file=None, ) self.assertEqual( - "postgresql://suser:spassword@shost:1234/sdbname?sslmode=require", + "postgresql://suser:spassword@shost:5432/sdbname", str(settings.src_postgres_dsn), ) + self.assertIsNone(settings.src_schema) + self.assertIsNone(settings.dst_schema) self.assertEqual( - "postgresql://duser:dpassword@dhost:4321/ddbname?sslmode=require", + "postgresql://duser:dpassword@dhost:5432/ddbname", str(settings.dst_postgres_dsn), ) diff --git a/tests/utils.py b/tests/utils.py index 770bb82..6f6f4e1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -80,9 +80,3 @@ def assertSuccess(self, result: Any) -> None: # pylint: disable=invalid-name @skipUnless(os.environ.get("REQUIRES_DB") == "1", "Set 'REQUIRES_DB=1' to enable.") class RequiresDBTestCase(SSGTestCase): """A test case that only runs if REQUIRES_DB has been set to 1.""" - - def setUp(self) -> None: - pass - - def tearDown(self) -> None: - pass