From 14b61550d82c4f370fd77b6c796cf0494c9378f8 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Mon, 29 May 2023 11:23:07 +0300 Subject: [PATCH 1/6] fix(trino,mysql): handle string typed decimal results --- superset/db_engine_specs/base.py | 25 +++++++++++- superset/db_engine_specs/mysql.py | 18 +++++---- superset/db_engine_specs/trino.py | 33 +++++++++++----- .../unit_tests/db_engine_specs/test_mysql.py | 37 +++++++++++++++++- .../unit_tests/db_engine_specs/test_trino.py | 39 ++++++++++++++++++- 5 files changed, 132 insertions(+), 20 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index a7ff8622722c1..dc9fe11d491b0 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -314,6 +314,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # engine-specific type mappings to check prior to the defaults column_type_mappings: Tuple[ColumnTypeMapping, ...] = () + # type-specific functions to mutate values received from the database. + # Needed on certain databases that return values in an unexpected format + column_type_mutators: dict[TypeEngine, Callable[[Any], Any]] = {} + # Does database support join-free timeslot grouping time_groupby_inline = False limit_method = LimitMethod.FORCE_LIMIT @@ -737,7 +741,26 @@ def fetch_data( try: if cls.limit_method == LimitMethod.FETCH_MANY and limit: return cursor.fetchmany(limit) - return cursor.fetchall() + data = cursor.fetchall() + column_type_mutators = { + row[0]: func + for row in cursor.description + if ( + func := cls.column_type_mutators.get( + type(cls.get_sqla_column_type(row[1])) + ) + ) + } + if column_type_mutators: + indexes = {row[0]: idx for idx, row in enumerate(cursor.description)} + for row_idx, row in enumerate(data): + new_row = list(row) + for col, func in column_type_mutators.items(): + col_idx = indexes[col] + new_row[col_idx] = func(row[col_idx]) + data[row_idx] = tuple(new_row) + + return data except Exception as ex: raise cls.get_dbapi_mapped_exception(ex) from ex diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 6258f6b21a4c6..faf04054016a7 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -16,7 +16,8 @@ # under the License. import re from datetime import datetime -from typing import Any, Dict, Optional, Pattern, Tuple +from decimal import Decimal +from typing import Any, Callable, Optional, Pattern from urllib import parse from flask_babel import gettext as __ @@ -123,6 +124,9 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): GenericDataType.STRING, ), ) + column_type_mutators: dict[types.TypeEngine, Callable[[Any], Any]] = { + DECIMAL: lambda val: float(val) if isinstance(val, (str, Decimal)) else val + } _time_grain_expressions = { None: "{col}", @@ -143,9 +147,9 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): "INTERVAL 1 DAY)) - 1 DAY))", } - type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed + type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_ACCESS_DENIED_REGEX: ( __('Either the username "%(username)s" or the password is incorrect.'), SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, @@ -186,7 +190,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -201,10 +205,10 @@ def convert_dttm( def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], catalog: Optional[str] = None, schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + ) -> tuple[URL, dict[str, Any]]: uri, new_connect_args = super().adjust_engine_params( uri, connect_args, @@ -221,7 +225,7 @@ def adjust_engine_params( def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], ) -> Optional[str]: """ Return the configured schema. diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 0fa4d05cbce0d..c59f6c3ab3ddd 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -17,12 +17,14 @@ from __future__ import annotations import logging -from typing import Any, Dict, Optional, Type, TYPE_CHECKING +from decimal import Decimal +from typing import Any, Callable, Optional, Type, TYPE_CHECKING import simplejson as json from flask import current_app from sqlalchemy.engine.url import URL from sqlalchemy.orm import Session +from sqlalchemy.types import DECIMAL, TypeEngine from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT from superset.databases.utils import make_url_safe @@ -48,13 +50,17 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): engine_name = "Trino" allows_alias_to_source_column = False + column_type_mutators: dict[TypeEngine, Callable[[Any], Any]] = { + DECIMAL: lambda val: float(val) if isinstance(val, (str, Decimal)) else val + } + @classmethod def extra_table_metadata( cls, database: Database, table_name: str, schema_name: Optional[str], - ) -> Dict[str, Any]: + ) -> dict[str, Any]: metadata = {} if indexes := database.get_indexes(table_name, schema_name): @@ -95,7 +101,7 @@ def extra_table_metadata( @classmethod def update_impersonation_config( cls, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], uri: str, username: Optional[str], ) -> None: @@ -131,7 +137,7 @@ def get_url_for_impersonation( return url @classmethod - def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: + def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: return True @classmethod @@ -199,7 +205,7 @@ def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: return True @staticmethod - def get_extra_params(database: Database) -> Dict[str, Any]: + def get_extra_params(database: Database) -> dict[str, Any]: """ Some databases require adding elements to connection parameters, like passing certificates to `extra`. This can be done here. @@ -207,9 +213,9 @@ def get_extra_params(database: Database) -> Dict[str, Any]: :param database: database instance from which to extract extras :raises CertificateException: If certificate is not valid/unparseable """ - extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database) - engine_params: Dict[str, Any] = extra.setdefault("engine_params", {}) - connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {}) + extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database) + engine_params: dict[str, Any] = extra.setdefault("engine_params", {}) + connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {}) connect_args.setdefault("source", USER_AGENT) @@ -222,7 +228,7 @@ def get_extra_params(database: Database) -> Dict[str, Any]: @staticmethod def update_params_from_encrypted_extra( database: Database, - params: Dict[str, Any], + params: dict[str, Any], ) -> None: if not database.encrypted_extra: return @@ -262,10 +268,17 @@ def update_params_from_encrypted_extra( raise ex @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[Type[Exception], Type[Exception]]: # pylint: disable=import-outside-toplevel from requests import exceptions as requests_exceptions return { requests_exceptions.ConnectionError: SupersetDBAPIConnectionError, } + + @classmethod + def fetch_data( + cls, cursor: Any, limit: Optional[int] = None + ) -> list[tuple[Any, ...]]: + data = super().fetch_data(cursor, limit) + return data diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py index 07ce6838fc20b..33965fe880b40 100644 --- a/tests/unit_tests/db_engine_specs/test_mysql.py +++ b/tests/unit_tests/db_engine_specs/test_mysql.py @@ -16,7 +16,8 @@ # under the License. from datetime import datetime -from typing import Any, Dict, Optional, Tuple, Type +from decimal import Decimal +from typing import Any, Dict, Optional, Type from unittest.mock import Mock, patch import pytest @@ -220,3 +221,37 @@ def test_get_schema_from_engine_params() -> None: ) == "db1" ) + + +@pytest.mark.parametrize( + "data,description,expected_result", + [ + ( + [["1.23456", "abc"]], + [("dec", "decimal(12,6)"), ("str", "varchar(3)")], + [(1.23456, "abc")], + ), + ( + [[Decimal("1.23456"), "abc"]], + [("dec", "decimal(12,6)"), ("str", "varchar(3)")], + [(1.23456, "abc")], + ), + ( + [["1.23456", "abc"]], + [("dec", "varchar(255)"), ("str", "varchar(3)")], + [["1.23456", "abc"]], + ), + ], +) +def test_column_type_mutator( + data: list[tuple[Any, ...]], + description: list[Any], + expected_result: list[tuple[Any, ...]], +): + from superset.db_engine_specs.trino import TrinoEngineSpec as spec + + mock_cursor = Mock() + mock_cursor.fetchall.return_value = data + mock_cursor.description = description + + assert spec.fetch_data(mock_cursor) == expected_result diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 0ea296a075e71..99731e0fcac0e 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -15,9 +15,12 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access +from __future__ import annotations + import json from datetime import datetime -from typing import Any, Dict, Optional, Type +from decimal import Decimal +from typing import Any, Dict, Optional, Type, Union from unittest.mock import Mock, patch import pandas as pd @@ -366,3 +369,37 @@ def test_handle_cursor_early_cancel( assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id else: assert cancel_query_mock.call_args is None + + +@pytest.mark.parametrize( + "data,description,expected_result", + [ + ( + [["1.23456", "abc"]], + [("dec", "decimal(12,6)"), ("str", "varchar(3)")], + [(1.23456, "abc")], + ), + ( + [[Decimal("1.23456"), "abc"]], + [("dec", "decimal(12,6)"), ("str", "varchar(3)")], + [(1.23456, "abc")], + ), + ( + [["1.23456", "abc"]], + [("dec", "varchar(255)"), ("str", "varchar(3)")], + [["1.23456", "abc"]], + ), + ], +) +def test_column_type_mutator( + data: list[Union[tuple[Any, ...], list[Any]]], + description: list[Any], + expected_result: list[Union[tuple[Any, ...], list[Any]]], +): + from superset.db_engine_specs.trino import TrinoEngineSpec as spec + + mock_cursor = Mock() + mock_cursor.fetchall.return_value = data + mock_cursor.description = description + + assert spec.fetch_data(mock_cursor) == expected_result From f5a2f73eb015fedec02f8908adc738f5c7d127f2 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Mon, 29 May 2023 12:27:33 +0300 Subject: [PATCH 2/6] fix postgres edge case --- superset/db_engine_specs/base.py | 2 +- superset/db_engine_specs/mysql.py | 2 +- superset/db_engine_specs/trino.py | 9 +-------- tests/unit_tests/db_engine_specs/test_mysql.py | 4 ++-- tests/unit_tests/db_engine_specs/test_trino.py | 4 ++-- 5 files changed, 7 insertions(+), 14 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index dc9fe11d491b0..2795f0aa0dfe1 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -747,7 +747,7 @@ def fetch_data( for row in cursor.description if ( func := cls.column_type_mutators.get( - type(cls.get_sqla_column_type(row[1])) + type(cls.get_sqla_column_type(cls.get_datatype(row[1]))) ) ) } diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index faf04054016a7..cd0a5c99a680b 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -125,7 +125,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): ), ) column_type_mutators: dict[types.TypeEngine, Callable[[Any], Any]] = { - DECIMAL: lambda val: float(val) if isinstance(val, (str, Decimal)) else val + DECIMAL: lambda val: Decimal(val) if isinstance(val, str) else val } _time_grain_expressions = { diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index c59f6c3ab3ddd..8228b3f0c87f5 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -51,7 +51,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): allows_alias_to_source_column = False column_type_mutators: dict[TypeEngine, Callable[[Any], Any]] = { - DECIMAL: lambda val: float(val) if isinstance(val, (str, Decimal)) else val + DECIMAL: lambda val: Decimal(val) if isinstance(val, str) else val } @classmethod @@ -275,10 +275,3 @@ def get_dbapi_exception_mapping(cls) -> dict[Type[Exception], Type[Exception]]: return { requests_exceptions.ConnectionError: SupersetDBAPIConnectionError, } - - @classmethod - def fetch_data( - cls, cursor: Any, limit: Optional[int] = None - ) -> list[tuple[Any, ...]]: - data = super().fetch_data(cursor, limit) - return data diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py index 33965fe880b40..dad00034af05d 100644 --- a/tests/unit_tests/db_engine_specs/test_mysql.py +++ b/tests/unit_tests/db_engine_specs/test_mysql.py @@ -229,12 +229,12 @@ def test_get_schema_from_engine_params() -> None: ( [["1.23456", "abc"]], [("dec", "decimal(12,6)"), ("str", "varchar(3)")], - [(1.23456, "abc")], + [(Decimal("1.23456"), "abc")], ), ( [[Decimal("1.23456"), "abc"]], [("dec", "decimal(12,6)"), ("str", "varchar(3)")], - [(1.23456, "abc")], + [(Decimal("1.23456"), "abc")], ), ( [["1.23456", "abc"]], diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 99731e0fcac0e..7ddbd4f1f9f4a 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -377,12 +377,12 @@ def test_handle_cursor_early_cancel( ( [["1.23456", "abc"]], [("dec", "decimal(12,6)"), ("str", "varchar(3)")], - [(1.23456, "abc")], + [(Decimal("1.23456"), "abc")], ), ( [[Decimal("1.23456"), "abc"]], [("dec", "decimal(12,6)"), ("str", "varchar(3)")], - [(1.23456, "abc")], + [(Decimal("1.23456"), "abc")], ), ( [["1.23456", "abc"]], From 1249654b388f7a71eca8db29d0bf873a307fbfb9 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Mon, 29 May 2023 13:13:08 +0300 Subject: [PATCH 3/6] handle missing decription gracefully --- superset/db_engine_specs/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 2795f0aa0dfe1..3767cdf541597 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -742,9 +742,10 @@ def fetch_data( if cls.limit_method == LimitMethod.FETCH_MANY and limit: return cursor.fetchmany(limit) data = cursor.fetchall() + description = cursor.description or [] column_type_mutators = { row[0]: func - for row in cursor.description + for row in description if ( func := cls.column_type_mutators.get( type(cls.get_sqla_column_type(cls.get_datatype(row[1]))) @@ -752,7 +753,7 @@ def fetch_data( ) } if column_type_mutators: - indexes = {row[0]: idx for idx, row in enumerate(cursor.description)} + indexes = {row[0]: idx for idx, row in enumerate(description)} for row_idx, row in enumerate(data): new_row = list(row) for col, func in column_type_mutators.items(): From 261baf432f6836d599b78b44f8653fc3c52224d2 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Tue, 30 May 2023 10:50:07 +0300 Subject: [PATCH 4/6] add comment + a unit test for None values --- superset/db_engine_specs/base.py | 9 ++++++--- tests/unit_tests/db_engine_specs/test_mysql.py | 5 +++++ tests/unit_tests/db_engine_specs/test_trino.py | 5 +++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 3767cdf541597..37e10e27419b8 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -743,7 +743,10 @@ def fetch_data( return cursor.fetchmany(limit) data = cursor.fetchall() description = cursor.description or [] - column_type_mutators = { + # Create a mapping between column name and a mutator function to normalize + # values with. The first two items in the description row are + # the column name and type. + column_mutators = { row[0]: func for row in description if ( @@ -752,11 +755,11 @@ def fetch_data( ) ) } - if column_type_mutators: + if column_mutators: indexes = {row[0]: idx for idx, row in enumerate(description)} for row_idx, row in enumerate(data): new_row = list(row) - for col, func in column_type_mutators.items(): + for col, func in column_mutators.items(): col_idx = indexes[col] new_row[col_idx] = func(row[col_idx]) data[row_idx] = tuple(new_row) diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py index dad00034af05d..7e593e04c8f42 100644 --- a/tests/unit_tests/db_engine_specs/test_mysql.py +++ b/tests/unit_tests/db_engine_specs/test_mysql.py @@ -236,6 +236,11 @@ def test_get_schema_from_engine_params() -> None: [("dec", "decimal(12,6)"), ("str", "varchar(3)")], [(Decimal("1.23456"), "abc")], ), + ( + [[None, "abc"]], + [("dec", "decimal(12,6)"), ("str", "varchar(3)")], + [(None, "abc")], + ), ( [["1.23456", "abc"]], [("dec", "varchar(255)"), ("str", "varchar(3)")], diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 7ddbd4f1f9f4a..d2e75a0842292 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -384,6 +384,11 @@ def test_handle_cursor_early_cancel( [("dec", "decimal(12,6)"), ("str", "varchar(3)")], [(Decimal("1.23456"), "abc")], ), + ( + [[None, "abc"]], + [("dec", "decimal(12,6)"), ("str", "varchar(3)")], + [(None, "abc")], + ), ( [["1.23456", "abc"]], [("dec", "varchar(255)"), ("str", "varchar(3)")], From 0f1fd48b9ce6320ec8d6f54120d7e0fb016c3d67 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 29 Sep 2023 09:34:26 -0700 Subject: [PATCH 5/6] revert trino changes --- superset/db_engine_specs/trino.py | 8 +--- .../unit_tests/db_engine_specs/test_trino.py | 48 ++----------------- 2 files changed, 4 insertions(+), 52 deletions(-) diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 68c96ead5638b..425137e302e6b 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -20,14 +20,12 @@ import logging import threading import time -from decimal import Decimal -from typing import Any, Callable, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import simplejson as json from flask import current_app from sqlalchemy.engine.url import URL from sqlalchemy.orm import Session -from sqlalchemy.types import DECIMAL, TypeEngine from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT from superset.databases.utils import make_url_safe @@ -51,10 +49,6 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): engine_name = "Trino" allows_alias_to_source_column = False - column_type_mutators: dict[TypeEngine, Callable[[Any], Any]] = { - DECIMAL: lambda val: Decimal(val) if isinstance(val, str) else val - } - @classmethod def extra_table_metadata( cls, diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index c8966508f6a63..1b50a683a0841 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -15,12 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access -from __future__ import annotations - import json from datetime import datetime -from decimal import Decimal -from typing import Any +from typing import Any, Optional from unittest.mock import Mock, patch import pandas as pd @@ -247,7 +244,7 @@ def test_auth_custom_auth_denied() -> None: def test_get_column_spec( native_type: str, sqla_type: type[types.TypeEngine], - attrs: dict[str, Any] | None, + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: @@ -276,7 +273,7 @@ def test_get_column_spec( ) def test_convert_dttm( target_type: str, - expected_result: str | None, + expected_result: Optional[str], dttm: datetime, ) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec @@ -398,42 +395,3 @@ def _mock_execute(*args, **kwargs): mock_query.set_extra_json_key.assert_called_once_with( key=QUERY_CANCEL_KEY, value=query_id ) - - -@pytest.mark.parametrize( - "data,description,expected_result", - [ - ( - [["1.23456", "abc"]], - [("dec", "decimal(12,6)"), ("str", "varchar(3)")], - [(Decimal("1.23456"), "abc")], - ), - ( - [[Decimal("1.23456"), "abc"]], - [("dec", "decimal(12,6)"), ("str", "varchar(3)")], - [(Decimal("1.23456"), "abc")], - ), - ( - [[None, "abc"]], - [("dec", "decimal(12,6)"), ("str", "varchar(3)")], - [(None, "abc")], - ), - ( - [["1.23456", "abc"]], - [("dec", "varchar(255)"), ("str", "varchar(3)")], - [["1.23456", "abc"]], - ), - ], -) -def test_column_type_mutator( - data: list[tuple[Any, ...] | list[Any]], - description: list[Any], - expected_result: list[tuple[Any, ...] | list[Any]], -): - from superset.db_engine_specs.trino import TrinoEngineSpec as spec - - mock_cursor = Mock() - mock_cursor.fetchall.return_value = data - mock_cursor.description = description - - assert spec.fetch_data(mock_cursor) == expected_result From 9659991317664c21056c75166ef696e574796aba Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 29 Sep 2023 09:53:06 -0700 Subject: [PATCH 6/6] fix test --- tests/unit_tests/db_engine_specs/test_mysql.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py index 2b32cef4e3415..ed643470176ec 100644 --- a/tests/unit_tests/db_engine_specs/test_mysql.py +++ b/tests/unit_tests/db_engine_specs/test_mysql.py @@ -227,24 +227,24 @@ def test_get_schema_from_engine_params() -> None: "data,description,expected_result", [ ( - [["1.23456", "abc"]], + [("1.23456", "abc")], [("dec", "decimal(12,6)"), ("str", "varchar(3)")], [(Decimal("1.23456"), "abc")], ), ( - [[Decimal("1.23456"), "abc"]], + [(Decimal("1.23456"), "abc")], [("dec", "decimal(12,6)"), ("str", "varchar(3)")], [(Decimal("1.23456"), "abc")], ), ( - [[None, "abc"]], + [(None, "abc")], [("dec", "decimal(12,6)"), ("str", "varchar(3)")], [(None, "abc")], ), ( - [["1.23456", "abc"]], + [("1.23456", "abc")], [("dec", "varchar(255)"), ("str", "varchar(3)")], - [["1.23456", "abc"]], + [("1.23456", "abc")], ), ], ) @@ -253,7 +253,7 @@ def test_column_type_mutator( description: list[Any], expected_result: list[tuple[Any, ...]], ): - from superset.db_engine_specs.trino import TrinoEngineSpec as spec + from superset.db_engine_specs.mysql import MySQLEngineSpec as spec mock_cursor = Mock() mock_cursor.fetchall.return_value = data