Skip to content

Commit

Permalink
refactor: make DDL overridable for column ADD, ALTER, and `RENAME…
Browse files Browse the repository at this point in the history
…` operations (#1114)

Co-authored-by: Edgar R. M <[email protected]>
  • Loading branch information
Ken Payne and edgarrmondragon authored Nov 9, 2022
1 parent 786fd95 commit 078629d
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 26 deletions.
118 changes: 92 additions & 26 deletions singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,21 +630,10 @@ def _create_empty_column(
if not self.allow_column_add:
raise NotImplementedError("Adding columns is not supported.")

create_column_clause = sqlalchemy.schema.CreateColumn(
sqlalchemy.Column(
column_name,
sql_type,
)
)
self.connection.execute(
sqlalchemy.DDL(
"ALTER TABLE %(table)s ADD COLUMN %(create_column)s",
{
"table": full_table_name,
"create_column": create_column_clause,
},
)
column_add_ddl = self.get_column_add_ddl(
table_name=full_table_name, column_name=column_name, column_type=sql_type
)
self.connection.execute(column_add_ddl)

def prepare_schema(self, schema_name: str) -> None:
"""Create the target database schema.
Expand Down Expand Up @@ -729,10 +718,10 @@ def rename_column(self, full_table_name: str, old_name: str, new_name: str) -> N
if not self.allow_column_rename:
raise NotImplementedError("Renaming columns is not supported.")

self.connection.execute(
f"ALTER TABLE {full_table_name} "
f'RENAME COLUMN "{old_name}" to "{new_name}"'
column_rename_ddl = self.get_column_rename_ddl(
table_name=full_table_name, column_name=old_name, new_column_name=new_name
)
self.connection.execute(column_rename_ddl)

def merge_sql_types(
self, sql_types: list[sqlalchemy.types.TypeEngine]
Expand Down Expand Up @@ -871,6 +860,87 @@ def _get_column_type(

return cast(sqlalchemy.types.TypeEngine, column.type)

@staticmethod
def get_column_add_ddl(
table_name: str, column_name: str, column_type: sqlalchemy.types.TypeEngine
) -> sqlalchemy.DDL:
"""Get the create column DDL statement.
Override this if your database uses a different syntax for creating columns.
Args:
table_name: Fully qualified table name of column to alter.
column_name: Column name to create.
column_type: New column sqlalchemy type.
Returns:
A sqlalchemy DDL instance.
"""
create_column_clause = sqlalchemy.schema.CreateColumn(
sqlalchemy.Column(
column_name,
column_type,
)
)
return sqlalchemy.DDL(
"ALTER TABLE %(table_name)s ADD COLUMN %(create_column_clause)s",
{
"table_name": table_name,
"create_column_clause": create_column_clause,
},
)

@staticmethod
def get_column_rename_ddl(
table_name: str, column_name: str, new_column_name: str
) -> sqlalchemy.DDL:
"""Get the create column DDL statement.
Override this if your database uses a different syntax for renaming columns.
Args:
table_name: Fully qualified table name of column to alter.
column_name: Existing column name.
new_column_name: New column name.
Returns:
A sqlalchemy DDL instance.
"""
return sqlalchemy.DDL(
"ALTER TABLE %(table_name)s "
"RENAME COLUMN %(column_name)s to %(new_column_name)s",
{
"table_name": table_name,
"column_name": column_name,
"new_column_name": new_column_name,
},
)

@staticmethod
def get_column_alter_ddl(
table_name: str, column_name: str, column_type: sqlalchemy.types.TypeEngine
) -> sqlalchemy.DDL:
"""Get the alter column DDL statement.
Override this if your database uses a different syntax for altering columns.
Args:
table_name: Fully qualified table name of column to alter.
column_name: Column name to alter.
column_type: New column type string.
Returns:
A sqlalchemy DDL instance.
"""
return sqlalchemy.DDL(
"ALTER TABLE %(table_name)s ALTER COLUMN %(column_name)s (%(column_type)s)",
{
"table_name": table_name,
"column_name": column_name,
"column_type": column_type,
},
)

def _adapt_column_type(
self,
full_table_name: str,
Expand Down Expand Up @@ -912,13 +982,9 @@ def _adapt_column_type(
f"from '{current_type}' to '{compatible_sql_type}'."
)

self.connection.execute(
sqlalchemy.DDL(
"ALTER TABLE %(table)s ALTER COLUMN %(col_name)s (%(col_type)s)",
{
"table": full_table_name,
"col_name": column_name,
"col_type": compatible_sql_type,
},
)
alter_column_ddl = self.get_column_alter_ddl(
table_name=full_table_name,
column_name=column_name,
column_type=compatible_sql_type,
)
self.connection.execute(alter_column_ddl)
93 changes: 93 additions & 0 deletions tests/core/test_connector_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import pytest
import sqlalchemy
from sqlalchemy.dialects import sqlite

from singer_sdk.connectors import SQLConnector


def stringify(in_dict):
return {k: str(v) for k, v in in_dict.items()}


class TestConnectorSQL:
"""Test the SQLConnector class."""

@pytest.fixture()
def connector(self):
return SQLConnector()

@pytest.mark.parametrize(
"method_name,kwargs,context,unrendered_statement,rendered_statement",
[
(
"get_column_add_ddl",
{
"table_name": "full.table.name",
"column_name": "column_name",
"column_type": sqlalchemy.types.Text(),
},
{
"table_name": "full.table.name",
"create_column_clause": sqlalchemy.schema.CreateColumn(
sqlalchemy.Column(
"column_name",
sqlalchemy.types.Text(),
)
),
},
"ALTER TABLE %(table_name)s ADD COLUMN %(create_column_clause)s",
"ALTER TABLE full.table.name ADD COLUMN column_name TEXT",
),
(
"get_column_rename_ddl",
{
"table_name": "full.table.name",
"column_name": "old_name",
"new_column_name": "new_name",
},
{
"table_name": "full.table.name",
"column_name": "old_name",
"new_column_name": "new_name",
},
"ALTER TABLE %(table_name)s RENAME COLUMN %(column_name)s to %(new_column_name)s",
"ALTER TABLE full.table.name RENAME COLUMN old_name to new_name",
),
(
"get_column_alter_ddl",
{
"table_name": "full.table.name",
"column_name": "column_name",
"column_type": sqlalchemy.types.String(),
},
{
"table_name": "full.table.name",
"column_name": "column_name",
"column_type": sqlalchemy.types.String(),
},
"ALTER TABLE %(table_name)s ALTER COLUMN %(column_name)s (%(column_type)s)",
"ALTER TABLE full.table.name ALTER COLUMN column_name (VARCHAR)",
),
],
)
def test_get_column_ddl(
self,
connector,
method_name,
kwargs,
context,
unrendered_statement,
rendered_statement,
):
method = getattr(connector, method_name)
column_ddl = method(**kwargs)

assert stringify(column_ddl.context) == stringify(context)
assert column_ddl.statement == unrendered_statement

statement = str(
column_ddl.compile(
dialect=sqlite.dialect(), compile_kwargs={"literal_binds": True}
)
)
assert statement == rendered_statement

0 comments on commit 078629d

Please sign in to comment.