From cffd2b232d9298321c34430f0ac829f6b6ebea8f Mon Sep 17 00:00:00 2001
From: lpoulain
Date: Thu, 21 Jul 2022 16:42:35 -0400
Subject: [PATCH] Optimize experimental_python_types and add type-mapping
 tests

Instead of checking the type for each row, check the type once for each
fetch() call and compute a list of lambdas which are then applied to the
values from each row. A new RowMapperFactory class is created to wrap
this behavior.

The experimental_python_types flag is now processed in the TrinoQuery
class instead of the TrinoResult class.

Type mapping tests are added for each lambda which maps row values to
Python types.
---
 tests/integration/test_types_integration.py | 247 ++++++++++++++++++
 trino/client.py                             | 261 +++++++++++++-------
 2 files changed, 420 insertions(+), 88 deletions(-)
 create mode 100644 tests/integration/test_types_integration.py

diff --git a/tests/integration/test_types_integration.py b/tests/integration/test_types_integration.py
new file mode 100644
index 00000000..5749a820
--- /dev/null
+++ b/tests/integration/test_types_integration.py
@@ -0,0 +1,247 @@
+import math
+import pytest
+from decimal import Decimal
+import trino
+
+
+@pytest.fixture
+def trino_connection(run_trino):
+    _, host, port = run_trino
+
+    yield trino.dbapi.Connection(
+        host=host, port=port, user="test", source="test", max_attempts=1
+    )
+
+
+def test_boolean(trino_connection):
+    SqlTest(trino_connection) \
+        .add_field(sql="CAST(null AS BOOLEAN)", python=None) \
+        .add_field(sql="false", python=False) \
+        .add_field(sql="true", python=True) \
+        .execute()
+
+
+def test_tinyint(trino_connection):
+    SqlTest(trino_connection) \
+        .add_field(sql="CAST(null AS TINYINT)", python=None) \
+        .add_field(sql="CAST(-128 AS TINYINT)", python=-128) \
+        .add_field(sql="CAST(42 AS TINYINT)", python=42) \
+        .add_field(sql="CAST(127 AS TINYINT)", python=127) \
+        .execute()
+
+
+def test_smallint(trino_connection):
+    SqlTest(trino_connection) \
+        .add_field(sql="CAST(null AS SMALLINT)", python=None) \
+        .add_field(sql="CAST(-32768 AS SMALLINT)", python=-32768) \
+        .add_field(sql="CAST(42 AS SMALLINT)", python=42) \
+        .add_field(sql="CAST(32767 AS SMALLINT)", python=32767) \
+        .execute()
+
+
+def test_int(trino_connection):
+    SqlTest(trino_connection) \
+        .add_field(sql="CAST(null AS INTEGER)", python=None) \
+        .add_field(sql="CAST(-2147483648 AS INTEGER)", python=-2147483648) \
+        .add_field(sql="CAST(83648 AS INTEGER)", python=83648) \
+        .add_field(sql="CAST(2147483647 AS INTEGER)", python=2147483647) \
+        .execute()
+
+
+def test_bigint(trino_connection):
+    SqlTest(trino_connection) \
+        .add_field(sql="CAST(null AS BIGINT)", python=None) \
+        .add_field(sql="CAST(-9223372036854775808 AS BIGINT)", python=-9223372036854775808) \
+        .add_field(sql="CAST(9223 AS BIGINT)", python=9223) \
+        .add_field(sql="CAST(9223372036854775807 AS BIGINT)", python=9223372036854775807) \
+        .execute()
+
+
+def test_real(trino_connection):
+    SqlTest(trino_connection) \
+        .add_field(sql="CAST(null AS REAL)", python=None) \
+        .add_field(sql="CAST('NaN' AS REAL)", python=math.nan) \
+        .add_field(sql="CAST('-Infinity' AS REAL)", python=-math.inf) \
+        .add_field(sql="CAST(3.4028235E38 AS REAL)", python=3.4028235e+38) \
+        .add_field(sql="CAST(1.4E-45 AS REAL)", python=1.4e-45) \
+        .add_field(sql="CAST('Infinity' AS REAL)", python=math.inf) \
+        .execute()
+
+
+def test_double(trino_connection):
+    SqlTest(trino_connection) \
+        .add_field(sql="CAST(null AS DOUBLE)", python=None) \
+        .add_field(sql="CAST('NaN' AS DOUBLE)", python=math.nan) \
.add_field(sql="CAST('-Infinity' AS DOUBLE)", python=-math.inf) \ + .add_field(sql="CAST(1.7976931348623157E308 AS DOUBLE)", python=1.7976931348623157e+308) \ + .add_field(sql="CAST(4.9E-324 AS DOUBLE)", python=5e-324) \ + .add_field(sql="CAST('Infinity' AS DOUBLE)", python=math.inf) \ + .execute() + + +def test_decimal(trino_connection): + SqlTest(trino_connection) \ + .add_field(sql="CAST(null AS DECIMAL)", python=None) \ + .add_field(sql="CAST(null AS DECIMAL(38,0))", python=None) \ + .add_field(sql="DECIMAL '10.3'", python=Decimal('10.3')) \ + .add_field(sql="CAST('0.123456789123456789' AS DECIMAL(18,18))", python=Decimal('0.123456789123456789')) \ + .add_field(sql="CAST(null AS DECIMAL(18,18))", python=None) \ + .add_field(sql="CAST('234.123456789123456789' AS DECIMAL(18,4))", python=Decimal('234.1235')) \ + .add_field(sql="CAST('10.3' AS DECIMAL(38,1))", python=Decimal('10.3')) \ + .add_field(sql="CAST('0.123456789123456789' AS DECIMAL(18,2))", python=Decimal('0.12')) \ + .add_field(sql="CAST('0.3123' AS DECIMAL(38,38))", python=Decimal('0.3123')) \ + .execute() + + +def test_varchar(trino_connection): + SqlTest(trino_connection) \ + .add_field(sql="'aaa'", python='aaa') \ + .add_field(sql="U&'Hello winter \2603 !'", python='Hello winter °3 !') \ + .add_field(sql="CAST(null AS VARCHAR)", python=None) \ + .add_field(sql="CAST('bbb' AS VARCHAR(1))", python='b') \ + .add_field(sql="CAST(null AS VARCHAR(1))", python=None) \ + .execute() + + +def test_char(trino_connection): + SqlTest(trino_connection) \ + .add_field(sql="CAST('ccc' AS CHAR)", python='c') \ + .add_field(sql="CAST(null AS CHAR)", python=None) \ + .add_field(sql="CAST('ddd' AS CHAR(1))", python='d') \ + .add_field(sql="CAST('😂' AS CHAR(1))", python='😂') \ + .add_field(sql="CAST(null AS CHAR(1))", python=None) \ + .execute() + + +def test_varbinary(trino_connection): + SqlTest(trino_connection) \ + .add_field(sql="X'65683F'", python='ZWg/') \ + .add_field(sql="X''", python='') \ + .add_field(sql="CAST('' AS VARBINARY)", python='') \ + .add_field(sql="from_utf8(CAST('😂😂😂😂😂😂' AS VARBINARY))", python='😂😂😂😂😂😂') \ + .add_field(sql="CAST(null AS VARBINARY)", python=None) \ + .execute() + + +def test_varbinary_failure(trino_connection): + SqlExpectFailureTest(trino_connection) \ + .execute("CAST(42 AS VARBINARY)") + + +def test_json(trino_connection): + SqlTest(trino_connection) \ + .add_field(sql="CAST('{}' AS JSON)", python='"{}"') \ + .add_field(sql="CAST('null' AS JSON)", python='"null"') \ + .add_field(sql="CAST(null AS JSON)", python=None) \ + .add_field(sql="CAST('3.14' AS JSON)", python='"3.14"') \ + .add_field(sql="CAST('a string' AS JSON)", python='"a string"') \ + .add_field(sql="CAST('a \" complex '' string :' AS JSON)", python='"a \\" complex \' string :"') \ + .add_field(sql="CAST('[]' AS JSON)", python='"[]"') \ + .execute() + + +def test_interval(trino_connection): + SqlTest(trino_connection) \ + .add_field(sql="CAST(null AS INTERVAL YEAR TO MONTH)", python=None) \ + .add_field(sql="CAST(null AS INTERVAL DAY TO SECOND)", python=None) \ + .add_field(sql="INTERVAL '3' MONTH", python='0-3') \ + .add_field(sql="INTERVAL '2' DAY", python='2 00:00:00.000') \ + .add_field(sql="INTERVAL '-2' DAY", python='-2 00:00:00.000') \ + .execute() + + +def test_array(trino_connection): + SqlTest(trino_connection) \ + .add_field(sql="CAST(null AS ARRAY(VARCHAR))", python=None) \ + .add_field(sql="ARRAY['a', 'b', null]", python=['a', 'b', None]) \ + .execute() + + +def test_map(trino_connection): + SqlTest(trino_connection) \ + 
.add_field(sql="CAST(null AS MAP(VARCHAR, INTEGER))", python=None) \ + .add_field(sql="MAP(ARRAY['a', 'b'], ARRAY[1, null])", python={'a': 1, 'b': None}) \ + .execute() + + +def test_row(trino_connection): + SqlTest(trino_connection) \ + .add_field(sql="CAST(null AS ROW(x BIGINT, y DOUBLE))", python=None) \ + .add_field(sql="CAST(ROW(1, 2e0) AS ROW(x BIGINT, y DOUBLE))", python=(1, 2.0)) \ + .execute() + + +def test_ipaddress(trino_connection): + SqlTest(trino_connection) \ + .add_field(sql="CAST(null AS IPADDRESS)", python=None) \ + .add_field(sql="IPADDRESS '2001:db8::1'", python='2001:db8::1') \ + .execute() + + +def test_uuid(trino_connection): + SqlTest(trino_connection) \ + .add_field(sql="CAST(null AS UUID)", python=None) \ + .add_field(sql="UUID '12151fd2-7586-11e9-8f9e-2a86e4085a59'", python='12151fd2-7586-11e9-8f9e-2a86e4085a59') \ + .execute() + + +def test_digest(trino_connection): + SqlTest(trino_connection) \ + .add_field(sql="CAST(null AS HyperLogLog)", python=None) \ + .add_field(sql="CAST(null AS P4HyperLogLog)", python=None) \ + .add_field(sql="CAST(null AS SetDigest)", python=None) \ + .add_field(sql="CAST(null AS QDigest(BIGINT))", python=None) \ + .add_field(sql="CAST(null AS TDigest)", python=None) \ + .add_field(sql="approx_set(1)", python='AgwBAIADRAA=') \ + .add_field(sql="CAST(approx_set(1) AS P4HyperLogLog)", python='AwwAAAAg' + 'A' * 2730 + '==') \ + .add_field(sql="make_set_digest(1)", python='AQgAAAACCwEAgANEAAAgAAABAAAASsQF+7cDRAABAA==') \ + .add_field(sql="tdigest_agg(1)", + python='AAAAAAAAAPA/AAAAAAAA8D8AAAAAAABZQAAAAAAAAPA/AQAAAAAAAAAAAPA/AAAAAAAA8D8=') \ + .execute() + + +class SqlTest: + def __init__(self, trino_connection): + self.cur = trino_connection.cursor(experimental_python_types=True) + self.sql_args = [] + self.expected_result = [] + + def add_field(self, sql, python): + self.sql_args.append(sql) + self.expected_result.append(python) + return self + + def execute(self): + sql = 'SELECT ' + ',\n'.join(self.sql_args) + + self.cur.execute(sql) + actual_result = self.cur.fetchall() + self._compare_results(actual_result[0], self.expected_result) + + def _compare_results(self, actual, expected): + assert len(actual) == len(expected) + + for idx, actual_val in enumerate(actual): + expected_val = expected[idx] + if type(actual_val) == float and math.isnan(actual_val) \ + and type(expected_val) == float and math.isnan(expected_val): + continue + + assert actual_val == expected_val + + +class SqlExpectFailureTest: + def __init__(self, trino_connection): + self.cur = trino_connection.cursor(experimental_python_types=True) + + def execute(self, field): + sql = 'SELECT ' + field + + try: + self.cur.execute(sql) + self.cur.fetchall() + success = True + except Exception: + success = False + + assert not success, "Test not expected to succeed" diff --git a/trino/client.py b/trino/client.py index 211ce9c0..fd0eeaeb 100644 --- a/trino/client.py +++ b/trino/client.py @@ -592,11 +592,10 @@ class TrinoResult(object): https://docs.python.org/3/library/stdtypes.html#generator-types """ - def __init__(self, query, rows=None, experimental_python_types: bool = False): + def __init__(self, query, rows=None): self._query = query self._rows = rows or [] self._rownumber = 0 - self._experimental_python_types = experimental_python_types @property def rownumber(self) -> int: @@ -606,7 +605,7 @@ def __iter__(self): # Initial fetch from the first POST request for row in self._rows: self._rownumber += 1 - yield self._map_row(self._experimental_python_types, row, 
+            yield row
         self._rows = None
 
         # Subsequent fetches from GET requests until next_uri is empty.
@@ -615,93 +614,12 @@ def __iter__(self):
             for row in rows:
                 self._rownumber += 1
                 logger.debug("row %s", row)
-                yield self._map_row(self._experimental_python_types, row, self._query.columns)
+                yield row
 
     @property
     def response_headers(self):
         return self._query.response_headers
 
-    @classmethod
-    def _map_row(cls, experimental_python_types, row, columns):
-        if not experimental_python_types:
-            return row
-        else:
-            return cls._map_to_python_types(cls, row, columns)
-
-    @classmethod
-    def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any:
-        (value, data_type) = item
-
-        if value is None:
-            return None
-
-        raw_type = data_type["typeSignature"]["rawType"]
-
-        try:
-            if isinstance(value, list):
-                if raw_type == "array":
-                    raw_type = {
-                        "typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
-                    }
-                    return [cls._map_to_python_type((array_item, raw_type)) for array_item in value]
-                if raw_type == "row":
-                    raw_types = map(lambda arg: arg["value"], data_type["typeSignature"]["arguments"])
-                    return tuple(
-                        cls._map_to_python_type((array_item, raw_type))
-                        for (array_item, raw_type) in zip(value, raw_types)
-                    )
-                return value
-            if isinstance(value, dict):
-                raw_key_type = {
-                    "typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
-                }
-                raw_value_type = {
-                    "typeSignature": data_type["typeSignature"]["arguments"][1]["value"]
-                }
-                return {
-                    cls._map_to_python_type((key, raw_key_type)):
-                        cls._map_to_python_type((value[key], raw_value_type))
-                    for key in value
-                }
-            elif "decimal" in raw_type:
-                return Decimal(value)
-            elif raw_type == "double":
-                if value == 'Infinity':
-                    return INF
-                elif value == '-Infinity':
-                    return NEGATIVE_INF
-                elif value == 'NaN':
-                    return NAN
-                return value
-            elif raw_type == "date":
-                return datetime.strptime(value, "%Y-%m-%d").date()
-            elif raw_type == "timestamp with time zone":
-                dt, tz = value.rsplit(' ', 1)
-                if tz.startswith('+') or tz.startswith('-'):
-                    return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f %z")
-                return datetime.strptime(dt, "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=pytz.timezone(tz))
-            elif "timestamp" in raw_type:
-                return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f")
-            elif "time with time zone" in raw_type:
-                matches = re.match(r'^(.*)([\+\-])(\d{2}):(\d{2})$', value)
-                assert matches is not None
-                assert len(matches.groups()) == 4
-                if matches.group(2) == '-':
-                    tz = -timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
-                else:
-                    tz = timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
-                return datetime.strptime(matches.group(1), "%H:%M:%S.%f").time().replace(tzinfo=timezone(tz))
-            elif "time" in raw_type:
-                return datetime.strptime(value, "%H:%M:%S.%f").time()
-            else:
-                return value
-        except ValueError as e:
-            error_str = f"Could not convert '{value}' into the associated python type for '{raw_type}'"
-            raise trino.exceptions.TrinoDataError(error_str) from e
-
-    def _map_to_python_types(self, row: List[Any], columns: List[Dict[str, Any]]) -> List[Any]:
-        return list(map(self._map_to_python_type, zip(row, columns)))
-
 
 class TrinoQuery(object):
     """Represent the execution of a SQL statement by Trino."""
@@ -723,9 +641,10 @@ def __init__(
         self._request = request
         self._update_type = None
         self._sql = sql
-        self._result = TrinoResult(self, experimental_python_types=experimental_python_types)
+        self._result = TrinoResult(self)
         self._response_headers = None
+        self._experimental_python_types = experimental_python_types
+        self._row_mapper: Optional[RowMapper] = None
 
     @property
     def columns(self):
@@ -776,12 +695,18 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
         self._warnings = getattr(status, "warnings", [])
         if status.next_uri is None:
             self._finished = True
-            self._result = TrinoResult(self, status.rows, self._experimental_python_types)
+
+        rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
+
+        self._result = TrinoResult(self, rows)
         return self._result
 
     def _update_state(self, status):
         self._stats.update(status.stats)
         self._update_type = status.update_type
+        if not self._row_mapper and status.columns:
+            self._row_mapper = RowMapperFactory().create(columns=status.columns,
+                                                         experimental_python_types=self._experimental_python_types)
         if status.columns:
             self._columns = status.columns
 
@@ -794,7 +719,11 @@ def fetch(self) -> List[List[Any]]:
         self._response_headers = response.headers
         if status.next_uri is None:
             self._finished = True
-        return status.rows
+
+        if not self._row_mapper:
+            return []
+
+        return self._row_mapper.map(status.rows)
 
     def cancel(self) -> None:
         """Cancel the current query"""
@@ -857,3 +786,159 @@ def decorated(*args, **kwargs):
             return decorated
 
     return wrapper
+
+
+class NoOpRowMapper:
+    """
+    No-op RowMapper which does not perform any transformation.
+    Used when experimental_python_types is False.
+    """
+
+    def map(self, rows):
+        return rows
+
+
+class RowMapperFactory:
+    """
+    Given the 'columns' result from Trino, generate a list of lambda
+    functions (one for each column) which will process a data value,
+    and return a RowMapper instance which will process rows of data.
+    """
+    no_op_row_mapper = NoOpRowMapper()
+
+    def create(self, columns, experimental_python_types):
+        assert columns is not None
+
+        if experimental_python_types:
+            return RowMapper([self._col_func(column['typeSignature']) for column in columns])
+        return RowMapperFactory.no_op_row_mapper
+
+    def _col_func(self, column):
+        col_type = column['rawType']
+
+        if col_type == 'array':
+            return self._array_map_func(column)
+        elif col_type == 'row':
+            return self._row_map_func(column)
+        elif col_type == 'map':
+            return self._map_map_func(column)
+        elif col_type.startswith('decimal'):
+            return lambda val: Decimal(val)
+        elif col_type.startswith('double') or col_type.startswith('real'):
+            return self._double_map_func()
+        elif col_type.startswith('timestamp'):
+            return self._timestamp_map_func(column, col_type)
+        elif col_type.startswith('time'):
+            return self._time_map_func(column, col_type)
+        elif col_type == 'date':
+            return lambda val: datetime.strptime(val, '%Y-%m-%d').date()
+        else:
+            return lambda val: val
+
+    def _array_map_func(self, column):
+        element_mapping_func = self._col_func(column['arguments'][0]['value'])
+        return lambda values: [element_mapping_func(value) for value in values]
+
+    def _row_map_func(self, column):
+        element_mapping_func = [self._col_func(arg['value']['typeSignature']) for arg in column['arguments']]
+        return lambda values: tuple(element_mapping_func[idx](value) for idx, value in enumerate(values))
+
+    def _map_map_func(self, column):
+        key_mapping_func = self._col_func(column['arguments'][0]['value'])
+        value_mapping_func = self._col_func(column['arguments'][1]['value'])
+        return lambda values: {key_mapping_func(key): value_mapping_func(value) for key, value in values.items()}
+
+    def _double_map_func(self):
+        return lambda val: INF if val == 'Infinity' \
+            else NEGATIVE_INF if val == '-Infinity' \
+            else NAN if val == 'NaN' \
+            else float(val)
+
+    def _timestamp_map_func(self, column, col_type):
+        datetime_default_size = 20  # size of 'YYYY-MM-DD HH:MM:SS.' (the datetime string up to the fractional seconds)
+        pattern = "%Y-%m-%d %H:%M:%S"
+        ms_size, ms_to_trim = self._get_number_of_digits(column)
+        if ms_size > 0:
+            pattern += ".%f"
+
+        dt_size = datetime_default_size + ms_size - ms_to_trim
+        dt_tz_offset = datetime_default_size + ms_size
+        if 'with time zone' in col_type:
+            if ms_to_trim > 0:
+                return lambda val: \
+                    [datetime.strptime(val[:dt_size] + val[dt_tz_offset:], pattern + ' %z')
+                     if tz.startswith('+') or tz.startswith('-')
+                     else datetime.strptime(dt[:dt_size] + dt[dt_tz_offset:], pattern)
+                     .replace(tzinfo=pytz.timezone(tz))
+                     for dt, tz in [val.rsplit(' ', 1)]][0]
+            else:
+                return lambda val: [datetime.strptime(val, pattern + ' %z')
+                                    if tz.startswith('+') or tz.startswith('-')
+                                    else datetime.strptime(dt, pattern).replace(tzinfo=pytz.timezone(tz))
+                                    for dt, tz in [val.rsplit(' ', 1)]][0]
+
+        if ms_to_trim > 0:
+            return lambda val: datetime.strptime(val[:dt_size] + val[dt_tz_offset:], pattern)
+        else:
+            return lambda val: datetime.strptime(val, pattern)
+
+    def _time_map_func(self, column, col_type):
+        pattern = "%H:%M:%S"
+        ms_size, ms_to_trim = self._get_number_of_digits(column)
+        if ms_size > 0:
+            pattern += ".%f"
+
+        time_size = 9 + ms_size - ms_to_trim
+
+        if 'with time zone' in col_type:
+            return lambda val: self._get_time_with_timezone(val, time_size, pattern)
+        else:
+            return lambda val: datetime.strptime(val[:time_size], pattern).time()
+
+    def _get_time_with_timezone(self, value, time_size, pattern):
+        matches = re.match(r'^(.*)([\+\-])(\d{2}):(\d{2})$', value)
+        assert matches is not None
+        assert len(matches.groups()) == 4
+        if matches.group(2) == '-':
+            tz = -timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
+        else:
+            tz = timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
+        return datetime.strptime(matches.group(1)[:time_size], pattern).time().replace(tzinfo=timezone(tz))
+
+    def _get_number_of_digits(self, column):
+        args = column['arguments']
+        if len(args) == 0:
+            return 3, 0
+        ms_size = column['arguments'][0]['value']
+        if ms_size == 0:
+            return -1, 0
+        ms_to_trim = ms_size - min(ms_size, 6)
+        return ms_size, ms_to_trim
+
+
+class RowMapper:
+    """
+    Maps a row of data given a list of mapping functions.
+    """
+
+    def __init__(self, columns=[]):
+        self.columns = columns
+
+    def map(self, rows):
+        if len(self.columns) == 0:
+            return rows
+        return [self._map_row(row) for row in rows]
+
+    def _map_row(self, row):
+        return [self._map_value(value, self.columns[idx]) for idx, value in enumerate(row)]
+
+    def _map_value(self, value, col_mapping_func):
+        if value is None:
+            return None
+
+        try:
+            return col_mapping_func(value)
+        except ValueError as e:
+            error_str = f"Could not convert '{value}' into the associated python type"
+            raise trino.exceptions.TrinoDataError(error_str) from e
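
---

A minimal usage sketch (not part of the patch, assuming the change above is
applied) showing how the two new pieces fit together. The column metadata
below is a hand-built stand-in for the 'columns' field of a Trino REST
response, containing only the keys that RowMapperFactory actually reads:

    from decimal import Decimal
    from trino.client import RowMapperFactory

    # One typeSignature dict per result column, in the shape Trino returns.
    columns = [
        {'typeSignature': {'rawType': 'decimal', 'arguments': []}},
        {'typeSignature': {'rawType': 'varchar', 'arguments': []}},
    ]

    # The mapping lambdas are computed once per result set ...
    mapper = RowMapperFactory().create(columns=columns,
                                       experimental_python_types=True)

    # ... and then applied to every fetched row; None values stay None.
    rows = mapper.map([['12.34', 'abc'], ['0.1', None]])
    assert rows == [[Decimal('12.34'), 'abc'], [Decimal('0.1'), None]]

With experimental_python_types=False, create() instead returns the shared
NoOpRowMapper instance and rows pass through untouched.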
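The trickiest mapping is the timestamp one: Python's %f directive parses at
most six fractional digits, so for precisions above six the generated lambda
trims the extra digits before calling strptime. A small sketch, again with
hand-built metadata, for a TIMESTAMP(9) column:

    from trino.client import RowMapperFactory

    col = {'typeSignature': {'rawType': 'timestamp',
                             'arguments': [{'value': 9}]}}
    mapper = RowMapperFactory().create(columns=[col],
                                       experimental_python_types=True)

    # Here ms_size=9 and ms_to_trim=3, so '123456789' is cut down to
    # '123456' (nanoseconds truncated to microseconds) before parsing.
    rows = mapper.map([['2001-08-22 03:04:05.123456789']])
    # rows[0][0] == datetime.datetime(2001, 8, 22, 3, 4, 5, 123456)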