diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index 042691b3b..2bb102237 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -630,21 +630,10 @@ def _create_empty_column( if not self.allow_column_add: raise NotImplementedError("Adding columns is not supported.") - create_column_clause = sqlalchemy.schema.CreateColumn( - sqlalchemy.Column( - column_name, - sql_type, - ) - ) - self.connection.execute( - sqlalchemy.DDL( - "ALTER TABLE %(table)s ADD COLUMN %(create_column)s", - { - "table": full_table_name, - "create_column": create_column_clause, - }, - ) + column_add_ddl = self.get_column_add_ddl( + table_name=full_table_name, column_name=column_name, column_type=sql_type ) + self.connection.execute(column_add_ddl) def prepare_schema(self, schema_name: str) -> None: """Create the target database schema. @@ -729,10 +718,10 @@ def rename_column(self, full_table_name: str, old_name: str, new_name: str) -> N if not self.allow_column_rename: raise NotImplementedError("Renaming columns is not supported.") - self.connection.execute( - f"ALTER TABLE {full_table_name} " - f'RENAME COLUMN "{old_name}" to "{new_name}"' + column_rename_ddl = self.get_column_rename_ddl( + table_name=full_table_name, column_name=old_name, new_column_name=new_name ) + self.connection.execute(column_rename_ddl) def merge_sql_types( self, sql_types: list[sqlalchemy.types.TypeEngine] @@ -871,6 +860,87 @@ def _get_column_type( return cast(sqlalchemy.types.TypeEngine, column.type) + @staticmethod + def get_column_add_ddl( + table_name: str, column_name: str, column_type: sqlalchemy.types.TypeEngine + ) -> sqlalchemy.DDL: + """Get the create column DDL statement. + + Override this if your database uses a different syntax for creating columns. + + Args: + table_name: Fully qualified table name of column to alter. + column_name: Column name to create. + column_type: New column sqlalchemy type. + + Returns: + A sqlalchemy DDL instance. + """ + create_column_clause = sqlalchemy.schema.CreateColumn( + sqlalchemy.Column( + column_name, + column_type, + ) + ) + return sqlalchemy.DDL( + "ALTER TABLE %(table_name)s ADD COLUMN %(create_column_clause)s", + { + "table_name": table_name, + "create_column_clause": create_column_clause, + }, + ) + + @staticmethod + def get_column_rename_ddl( + table_name: str, column_name: str, new_column_name: str + ) -> sqlalchemy.DDL: + """Get the create column DDL statement. + + Override this if your database uses a different syntax for renaming columns. + + Args: + table_name: Fully qualified table name of column to alter. + column_name: Existing column name. + new_column_name: New column name. + + Returns: + A sqlalchemy DDL instance. + """ + return sqlalchemy.DDL( + "ALTER TABLE %(table_name)s " + "RENAME COLUMN %(column_name)s to %(new_column_name)s", + { + "table_name": table_name, + "column_name": column_name, + "new_column_name": new_column_name, + }, + ) + + @staticmethod + def get_column_alter_ddl( + table_name: str, column_name: str, column_type: sqlalchemy.types.TypeEngine + ) -> sqlalchemy.DDL: + """Get the alter column DDL statement. + + Override this if your database uses a different syntax for altering columns. + + Args: + table_name: Fully qualified table name of column to alter. + column_name: Column name to alter. + column_type: New column type string. + + Returns: + A sqlalchemy DDL instance. + """ + return sqlalchemy.DDL( + "ALTER TABLE %(table_name)s ALTER COLUMN %(column_name)s (%(column_type)s)", + { + "table_name": table_name, + "column_name": column_name, + "column_type": column_type, + }, + ) + def _adapt_column_type( self, full_table_name: str, @@ -912,13 +982,9 @@ def _adapt_column_type( f"from '{current_type}' to '{compatible_sql_type}'." ) - self.connection.execute( - sqlalchemy.DDL( - "ALTER TABLE %(table)s ALTER COLUMN %(col_name)s (%(col_type)s)", - { - "table": full_table_name, - "col_name": column_name, - "col_type": compatible_sql_type, - }, - ) + alter_column_ddl = self.get_column_alter_ddl( + table_name=full_table_name, + column_name=column_name, + column_type=compatible_sql_type, ) + self.connection.execute(alter_column_ddl) diff --git a/tests/core/test_connector_sql.py b/tests/core/test_connector_sql.py new file mode 100644 index 000000000..518b17bdf --- /dev/null +++ b/tests/core/test_connector_sql.py @@ -0,0 +1,93 @@ +import pytest +import sqlalchemy +from sqlalchemy.dialects import sqlite + +from singer_sdk.connectors import SQLConnector + + +def stringify(in_dict): + return {k: str(v) for k, v in in_dict.items()} + + +class TestConnectorSQL: + """Test the SQLConnector class.""" + + @pytest.fixture() + def connector(self): + return SQLConnector() + + @pytest.mark.parametrize( + "method_name,kwargs,context,unrendered_statement,rendered_statement", + [ + ( + "get_column_add_ddl", + { + "table_name": "full.table.name", + "column_name": "column_name", + "column_type": sqlalchemy.types.Text(), + }, + { + "table_name": "full.table.name", + "create_column_clause": sqlalchemy.schema.CreateColumn( + sqlalchemy.Column( + "column_name", + sqlalchemy.types.Text(), + ) + ), + }, + "ALTER TABLE %(table_name)s ADD COLUMN %(create_column_clause)s", + "ALTER TABLE full.table.name ADD COLUMN column_name TEXT", + ), + ( + "get_column_rename_ddl", + { + "table_name": "full.table.name", + "column_name": "old_name", + "new_column_name": "new_name", + }, + { + "table_name": "full.table.name", + "column_name": "old_name", + "new_column_name": "new_name", + }, + "ALTER TABLE %(table_name)s RENAME COLUMN %(column_name)s to %(new_column_name)s", + "ALTER TABLE full.table.name RENAME COLUMN old_name to new_name", + ), + ( + "get_column_alter_ddl", + { + "table_name": "full.table.name", + "column_name": "column_name", + "column_type": sqlalchemy.types.String(), + }, + { + "table_name": "full.table.name", + "column_name": "column_name", + "column_type": sqlalchemy.types.String(), + }, + "ALTER TABLE %(table_name)s ALTER COLUMN %(column_name)s (%(column_type)s)", + "ALTER TABLE full.table.name ALTER COLUMN column_name (VARCHAR)", + ), + ], + ) + def test_get_column_ddl( + self, + connector, + method_name, + kwargs, + context, + unrendered_statement, + rendered_statement, + ): + method = getattr(connector, method_name) + column_ddl = method(**kwargs) + + assert stringify(column_ddl.context) == stringify(context) + assert column_ddl.statement == unrendered_statement + + statement = str( + column_ddl.compile( + dialect=sqlite.dialect(), compile_kwargs={"literal_binds": True} + ) + ) + assert statement == rendered_statement