Skip to content

Commit

Permalink
feat: Support conforming singer property names to target identifier c…
Browse files Browse the repository at this point in the history
…onstraints in SQL sinks (#1039)

Co-authored-by: Edgar R. M. <[email protected]>
Co-authored-by: Edgar R. M <[email protected]>
  • Loading branch information
Ken Payne and edgarrmondragon authored Oct 20, 2022
1 parent 2721dc5 commit 937acf3
Show file tree
Hide file tree
Showing 8 changed files with 271 additions and 11 deletions.
3 changes: 3 additions & 0 deletions samples/sample_tap_hostile/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""A sample tap for testing SQL target property name transformations."""

from .hostile_tap import SampleTapHostile
40 changes: 40 additions & 0 deletions samples/sample_tap_hostile/hostile_streams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

import random
import string
from typing import Iterable

from singer_sdk import typing as th
from singer_sdk.streams import Stream


class HostilePropertyNamesStream(Stream):
    """Sample stream whose property names are hostile to SQL targets.

    The declared property names contain spaces, dashes, mixed case, leading
    digits and an emoji, so they are not compatible as unescaped identifiers
    in common DBMS systems.
    """

    name = "hostile_property_names_stream"
    schema = th.PropertiesList(
        th.Property("name with spaces", th.StringType),
        th.Property("NameIsCamelCase", th.StringType),
        th.Property("name-with-dashes", th.StringType),
        th.Property("Name-with-Dashes-and-Mixed-cases", th.StringType),
        th.Property("5name_starts_with_number", th.StringType),
        th.Property("6name_starts_with_number", th.StringType),
        th.Property("7name_starts_with_number", th.StringType),
        th.Property("name_with_emoji_😈", th.StringType),
    ).to_dict()

    @staticmethod
    def get_random_lowercase_string():
        """Return a string of ten randomly chosen lowercase ASCII letters."""
        letters = [random.choice(string.ascii_lowercase) for _ in range(10)]
        return "".join(letters)

    def get_records(self, context: dict | None) -> Iterable[dict | tuple[dict, dict]]:
        """Yield ten records with a random string value for every property.

        Args:
            context: Stream partition or context dictionary (not used here).

        Yields:
            A dictionary keyed by each property name declared in the schema.
        """
        for _ in range(10):
            yield {
                property_name: self.get_random_lowercase_string()
                for property_name in self.schema["properties"]
            }
24 changes: 24 additions & 0 deletions samples/sample_tap_hostile/hostile_tap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""A sample tap for testing SQL target property name transformations."""

from typing import List

from samples.sample_tap_hostile.hostile_streams import HostilePropertyNamesStream
from singer_sdk import Stream, Tap
from singer_sdk.typing import PropertiesList


class SampleTapHostile(Tap):
    """Sample tap for testing SQL target property name transformations."""

    name: str = "sample-tap-hostile"
    config_jsonschema = PropertiesList().to_dict()

    def discover_streams(self) -> List[Stream]:
        """Return the streams exposed by this tap.

        Returns:
            A single-element list holding a ``HostilePropertyNamesStream``
            bound to this tap.
        """
        hostile_stream = HostilePropertyNamesStream(tap=self)
        return [hostile_stream]


# Invoke the tap's command line interface when this module is run as a script.
if __name__ == "__main__":
    SampleTapHostile.cli()
7 changes: 7 additions & 0 deletions singer_sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,10 @@ class TapStreamConnectionFailure(Exception):

class TooManyRecordsException(Exception):
"""Exception to raise when query returns more records than max_records."""


class ConformedNameClashException(Exception):
    """Signal that conforming names produced clashing results.

    For example, two distinct columns that were conformed to the same name.
    """
42 changes: 42 additions & 0 deletions singer_sdk/helpers/_conformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Helper functions for conforming identifiers."""
import re
from string import ascii_lowercase, digits


def snakecase(string: str) -> str:
    """Convert string into snake case.

    Dashes, dots and whitespace become underscores, every uppercase ASCII
    letter after the first character is lowercased and prefixed with an
    underscore, runs of underscores collapse to one and trailing underscores
    are removed.

    Args:
        string: String to convert.

    Returns:
        string: Snake cased string.
    """
    normalized = re.sub(r"[\-\.\s]", "_", string)
    if normalized:
        head, tail = normalized[0], normalized[1:]
        # The first character is only lowercased; it never gains a prefix.
        tail = re.sub(r"[A-Z]", lambda hit: "_" + hit.group(0).lower(), tail)
        normalized = head.lower() + tail
    collapsed = re.sub(r"_{2,}", "_", normalized)
    return collapsed.rstrip("_")


def replace_leading_digit(string: str) -> str:
    """Replace leading numeric character with equivalent letter.

    The digits ``0``-``9`` map to the letters ``a``-``j`` respectively, so
    ``"5name"`` becomes ``"fname"``.

    Args:
        string: String to process. May be empty.

    Returns:
        A modified string if original starts with a number,
        else the unmodified original (including the empty string).
    """
    # An empty string has no leading character; indexing it would raise
    # IndexError, so return it untouched.
    if not string or string[0] not in digits:
        return string
    # The digit's integer value is the position of its replacement letter.
    return ascii_lowercase[int(string[0])] + string[1:]
11 changes: 10 additions & 1 deletion singer_sdk/sinks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
self.latest_state: dict | None = None
self._draining_state: dict | None = None
self.drained_state: dict | None = None
self.key_properties = key_properties or []
self._key_properties = key_properties or []

# Tally counters
self._total_records_written: int = 0
Expand Down Expand Up @@ -202,6 +202,15 @@ def datetime_error_treatment(self) -> DatetimeErrorTreatmentEnum:
"""
return DatetimeErrorTreatmentEnum.ERROR

@property
def key_properties(self) -> list[str]:
    """Expose the key properties stored on this sink.

    Returns:
        The list of stream key properties held in ``self._key_properties``.
    """
    stored_keys = self._key_properties
    return stored_keys

# Record processing

def _add_sdc_metadata_to_record(
Expand Down
116 changes: 106 additions & 10 deletions singer_sdk/sinks/sql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Sink classes load data to SQL targets."""

import re
from collections import defaultdict
from copy import copy
from textwrap import dedent
from typing import Any, Dict, Iterable, List, Optional, Type, Union

Expand All @@ -8,6 +11,8 @@
from sqlalchemy.sql import Executable
from sqlalchemy.sql.expression import bindparam

from singer_sdk.exceptions import ConformedNameClashException
from singer_sdk.helpers._conformers import replace_leading_digit, snakecase
from singer_sdk.plugin_base import PluginBase
from singer_sdk.sinks.batch import BatchSink
from singer_sdk.streams import SQLConnector
Expand Down Expand Up @@ -67,7 +72,8 @@ def table_name(self) -> str:
The target table name.
"""
parts = self.stream_name.split("-")
return self.stream_name if len(parts) == 1 else parts[-1]
table = self.stream_name if len(parts) == 1 else parts[-1]
return self.conform_name(table, "table")

@property
def schema_name(self) -> Optional[str]:
Expand All @@ -80,7 +86,7 @@ def schema_name(self) -> Optional[str]:
if len(parts) in {2, 3}:
# Stream name is a two-part or three-part identifier.
# Use the second-to-last part as the schema name.
return parts[-2]
return self.conform_name(parts[-2], "schema")

# Schema name not detected.
return None
Expand Down Expand Up @@ -118,6 +124,86 @@ def full_schema_name(self) -> str:
schema_name=self.schema_name, db_name=self.database_name
)

def conform_name(self, name: str, object_type: Optional[str] = None) -> str:
    """Conform a stream property name to one suitable for the target system.

    By default names are transformed to snake case, which suits most common
    DBMSs. Developers may override this method to apply custom
    transformations to database/schema/table/column names.

    Args:
        name: Property name.
        object_type: One of ``database``, ``schema``, ``table`` or ``column``.
            This default implementation does not use it.

    Returns:
        The name transformed to snake case.
    """
    # Keep only ASCII letters, digits, "_", "-", "." and whitespace.
    cleaned = re.sub(r"[^a-zA-Z0-9_\-\.\s]", "", name)
    # Snake-case what is left, then substitute a leading digit.
    return replace_leading_digit(snakecase(cleaned))

@staticmethod
def _check_conformed_names_not_duplicated(
    conformed_property_names: Dict[str, str]
) -> None:
    """Verify that no two original names share one conformed name.

    Args:
        conformed_property_names: A name:conformed_name dict map.

    Raises:
        ConformedNameClashException: if duplicates found.
    """
    # Invert the map: conformed name -> all original names producing it,
    # e.g. {'abc': ['aBc', 'abC'], 'x': ['x']}
    originals_by_conformed = defaultdict(list)
    for original, conformed in conformed_property_names.items():
        originals_by_conformed[conformed].append(original)

    # Only conformed names reached from more than one original are clashes.
    duplicates = [
        entry for entry in originals_by_conformed.items() if len(entry[1]) > 1
    ]
    if not duplicates:
        return
    raise ConformedNameClashException(
        "Duplicate stream properties produced when "
        + f"conforming property names: {duplicates}"
    )

def conform_schema(self, schema: dict) -> dict:
    """Build a copy of the schema whose property names are conformed.

    The copy is shallow: only the top-level ``properties`` mapping is
    replaced, and the input schema's own ``properties`` entry is left as is.

    Args:
        schema: JSON schema dictionary.

    Returns:
        A schema dictionary with the property names conformed.
    """
    original_properties = schema["properties"]
    name_map = {prop: self.conform_name(prop) for prop in original_properties}
    self._check_conformed_names_not_duplicated(name_map)

    result = copy(schema)
    result["properties"] = {
        name_map[prop]: definition
        for prop, definition in original_properties.items()
    }
    return result

def conform_record(self, record: dict) -> dict:
    """Build a new record whose keys are the conformed property names.

    Args:
        record: Dictionary representing a single record.

    Returns:
        New record dictionary with conformed column names.
    """
    name_map = {}
    for field in record:
        name_map[field] = self.conform_name(field)
    self._check_conformed_names_not_duplicated(name_map)

    conformed = {}
    for field, value in record.items():
        conformed[name_map[field]] = value
    return conformed

def setup(self) -> None:
"""Set up Sink.
Expand All @@ -128,11 +214,20 @@ def setup(self) -> None:
self.connector.prepare_schema(self.schema_name)
self.connector.prepare_table(
full_table_name=self.full_table_name,
schema=self.schema,
schema=self.conform_schema(self.schema),
primary_keys=self.key_properties,
as_temp_table=False,
)

@property
def key_properties(self) -> List[str]:
    """Return key properties, conformed to target system naming requirements.

    Returns:
        The parent class's key properties, each passed through
        `self.conform_name()` as a ``column`` name.
    """
    parent_keys = super().key_properties
    return [self.conform_name(key, "column") for key in parent_keys]

def process_batch(self, context: dict) -> None:
"""Process a batch with the given batch context.
Expand Down Expand Up @@ -164,15 +259,14 @@ def generate_insert_statement(
Returns:
An insert statement.
"""
property_names = list(schema["properties"].keys())
property_names = list(self.conform_schema(schema)["properties"].keys())
statement = dedent(
f"""\
INSERT INTO {full_table_name}
({", ".join(property_names)})
VALUES ({", ".join([f":{name}" for name in property_names])})
"""
)

return statement.rstrip()

def bulk_insert_records(
Expand Down Expand Up @@ -203,12 +297,14 @@ def bulk_insert_records(
if isinstance(insert_sql, str):
insert_sql = sqlalchemy.text(insert_sql)

conformed_records = (
[self.conform_record(record) for record in records]
if isinstance(records, list)
else (self.conform_record(record) for record in records)
)
self.logger.info("Inserting with SQL: %s", insert_sql)
self.connector.connection.execute(insert_sql, records)
if isinstance(records, list):
return len(records) # If list, we can quickly return record count.

return None # Unknown record count.
self.connector.connection.execute(insert_sql, conformed_records)
return len(conformed_records) if isinstance(conformed_records, list) else None

def merge_upsert_from_table(
self, target_table_name: str, from_table_name: str, join_keys: List[str]
Expand Down
39 changes: 39 additions & 0 deletions tests/core/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytest
import sqlalchemy

from samples.sample_tap_hostile import SampleTapHostile
from samples.sample_tap_sqlite import SQLiteConnector, SQLiteTap
from samples.sample_target_csv.csv_target import SampleTargetCSV
from samples.sample_target_sqlite import SQLiteSink, SQLiteTarget
Expand Down Expand Up @@ -569,3 +570,41 @@ def test_sqlite_generate_insert_statement(
sink.schema,
)
assert dml == expected_dml


def test_hostile_to_sqlite(
    sqlite_sample_target: SQLTarget, sqlite_target_test_config: dict
):
    """Sync the hostile sample tap into SQLite and check the conformed names.

    Asserts that the stream's table was created and that its columns are the
    conformed forms of the tap's property names.
    """
    tap = SampleTapHostile()
    tap_to_target_sync_test(tap, sqlite_sample_target)
    db = sqlite3.connect(sqlite_target_test_config["path_to_db"])
    # Close the connection even when an assertion fails, so the database
    # file handle is not leaked.
    try:
        cursor = db.cursor()
        # check if stream table was created
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [res[0] for res in cursor.fetchall()]
        assert "hostile_property_names_stream" in tables
        # check if columns were conformed
        cursor.execute(
            dedent(
                """
                SELECT
                p.name as columnName
                FROM sqlite_master m
                left outer join pragma_table_info((m.name)) p
                on m.name <> p.name
                where m.name = 'hostile_property_names_stream'
                ;
                """
            )
        )
        columns = {res[0] for res in cursor.fetchall()}
        assert columns == {
            "name_with_spaces",
            "name_is_camel_case",
            "name_with_dashes",
            "name_with_dashes_and_mixed_cases",
            "gname_starts_with_number",
            "fname_starts_with_number",
            "hname_starts_with_number",
            "name_with_emoji",
        }
    finally:
        db.close()

0 comments on commit 937acf3

Please sign in to comment.