From 456dcb881aa46cca4504f3dd9c3cfc8f0cae9037 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 2 May 2022 22:41:12 -0400 Subject: [PATCH 1/7] feat: duckdb support --- .github/workflows/ci.yml | 3 + siuba/ops/support/base.py | 2 +- siuba/sql/dialects/_dt_generics.py | 2 +- siuba/sql/dialects/duckdb.py | 111 +++++++++++++++++++++++++++++ siuba/sql/dialects/postgresql.py | 2 +- siuba/sql/dply/vector.py | 3 +- siuba/sql/utils.py | 7 ++ siuba/sql/verbs.py | 19 +++-- siuba/tests/conftest.py | 1 + siuba/tests/helpers.py | 32 +++++++-- siuba/tests/test_sql_misc.py | 12 +++- siuba/tests/test_verb_count.py | 4 +- 12 files changed, 178 insertions(+), 20 deletions(-) create mode 100644 siuba/sql/dialects/duckdb.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 044bb81c..f8edacf4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,6 +49,9 @@ jobs: python -m pip install --upgrade pip python -m pip install $REQUIREMENTS python -m pip install -r requirements-test.txt + # step to test duckdb + # TODO: move these requirements into the test matrix + pip install duckdb_engine python -m pip install . 
env: REQUIREMENTS: ${{ matrix.requirements }} diff --git a/siuba/ops/support/base.py b/siuba/ops/support/base.py index 538338b3..d2529bab 100644 --- a/siuba/ops/support/base.py +++ b/siuba/ops/support/base.py @@ -8,7 +8,7 @@ from siuba.siu import FunctionLookupBound from siuba.sql.utils import get_dialect_translator -SQL_BACKENDS = ["postgresql", "redshift", "sqlite", "mysql", "bigquery", "snowflake"] +SQL_BACKENDS = ["postgresql", "redshift", "sqlite", "mysql", "bigquery", "snowflake", "duckdb"] ALL_BACKENDS = SQL_BACKENDS + ["pandas"] methods = pd.DataFrame( diff --git a/siuba/sql/dialects/_dt_generics.py b/siuba/sql/dialects/_dt_generics.py index 6e4847a3..7fb95def 100644 --- a/siuba/sql/dialects/_dt_generics.py +++ b/siuba/sql/dialects/_dt_generics.py @@ -14,7 +14,7 @@ def date_trunc(_, col, period): @symbolic_dispatch(cls = SqlColumn) def sql_func_last_day_in_period(codata, col, period): - return date_trunc(codata, col, period) + sql.text("interval '1 %s - 1 day'" % period) + return date_trunc(codata, col, period) + sql.text("INTERVAL '1 %s' - INTERVAL '1 day'" % period) # TODO: RENAME TO GEN diff --git a/siuba/sql/dialects/duckdb.py b/siuba/sql/dialects/duckdb.py new file mode 100644 index 00000000..c6c266fd --- /dev/null +++ b/siuba/sql/dialects/duckdb.py @@ -0,0 +1,111 @@ +from sqlalchemy.sql import func as fn +from sqlalchemy import sql + +from ..translate import ( + # data types + SqlColumn, SqlColumnAgg, + AggOver, + # transformations + wrap_annotate, + sql_agg, + win_agg, + win_cumul, + sql_not_impl, + # wiring up translator + extend_base, + SqlTranslator +) + +from .postgresql import ( + PostgresqlColumn, + PostgresqlColumnAgg, +) + +from .base import sql_func_rank + + +# Data ======================================================================== + +class DuckdbColumn(PostgresqlColumn): pass +class DuckdbColumnAgg(PostgresqlColumnAgg, DuckdbColumn): pass + + +# Annotations ================================================================= + +def 
returns_int(func_names): + # TODO: MC-NOTE - shift all translations to directly register + # TODO: MC-NOTE - make an AliasAnnotated class or something, that signals + # it is using another method, but w/ an updated annotation. + from siuba.ops import ALL_OPS + + for name in func_names: + generic = ALL_OPS[name] + f_concrete = generic.dispatch(SqlColumn) + f_annotated = wrap_annotate(f_concrete, result_type="int") + generic.register(DuckdbColumn, f_annotated) + + +# Translations ================================================================ + + +def sql_quantile(is_analytic=False): + # Ordered and theoretical set aggregates + sa_func = getattr(sql.func, "percentile_cont") + + def f_quantile(codata, col, q, *args): + if args: + raise NotImplementedError("Quantile only supports the q argument.") + if not isinstance(q, (int, float)): + raise TypeError("q argument must be int or float, but received: %s" %type(q)) + + # as far as I can tell, there's no easy way to tell sqlalchemy to render + # the exact text a dialect would render for a literal (except maybe using + # literal_column), so use the classic sql.text. 
+ q_text = sql.text(str(q)) + + if is_analytic: + return AggOver(sa_func(sql.text(q_text)).within_group(col)) + + return sa_func(q_text).within_group(col) + + return f_quantile + + +# scalar ---- + +extend_base( + DuckdbColumn, + **{ + "str.contains": lambda _, col, re: fn.regexp_matches(col, re), + "str.title": sql_not_impl(), + } +) + +returns_int([ + "dt.day", "dt.dayofyear", "dt.days_in_month", + "dt.daysinmonth", "dt.hour", "dt.minute", "dt.month", + "dt.quarter", "dt.second", "dt.week", + "dt.weekofyear", "dt.year" +]) + +# window ---- + +extend_base( + DuckdbColumn, + cumsum = win_cumul("sum"), + sum = win_agg("sum"), + rank = sql_func_rank, + #quantile = sql_quantile(is_analytic=True), +) + + +# aggregate ---- + +extend_base( + DuckdbColumnAgg, + sum = sql_agg("sum"), + quantile = sql_quantile(), +) + + +translator = SqlTranslator.from_mappings(DuckdbColumn, DuckdbColumnAgg) diff --git a/siuba/sql/dialects/postgresql.py b/siuba/sql/dialects/postgresql.py index 5de008ae..dc2d9952 100644 --- a/siuba/sql/dialects/postgresql.py +++ b/siuba/sql/dialects/postgresql.py @@ -31,7 +31,7 @@ class PostgresqlColumnAgg(SqlColumnAgg, PostgresqlColumn): pass @annotate(return_type="float") def sql_is_quarter_end(_, col): - last_day = fn.date_trunc("quarter", col) + sql.text("interval '3 month - 1 day'") + last_day = fn.date_trunc("quarter", col) + sql.text("INTERVAL '3 month' - INTERVAL '1 day'") return fn.date_trunc("day", col) == last_day diff --git a/siuba/sql/dply/vector.py b/siuba/sql/dply/vector.py index 698d6d5d..d29c4fd2 100644 --- a/siuba/sql/dply/vector.py +++ b/siuba/sql/dply/vector.py @@ -12,6 +12,7 @@ from ..dialects.sqlite import SqliteColumn from ..dialects.mysql import MysqlColumn from ..dialects.bigquery import BigqueryColumn +from ..dialects.duckdb import DuckdbColumn from siuba.dply.vector import ( #cumall, cumany, cummean, @@ -99,7 +100,7 @@ def f(_, col, na_option = None) -> RankOver: dense_rank .register(BigqueryColumn, _sql_rank("dense_rank", 
nulls_last = True)) percent_rank.register(BigqueryColumn, _sql_rank("percent_rank", nulls_last = True)) - +dense_rank .register(DuckdbColumn, _sql_rank("dense_rank", nulls_last = True)) # row_number ------------------------------------------------------------------ diff --git a/siuba/sql/utils.py b/siuba/sql/utils.py index c6391690..32829eea 100644 --- a/siuba/sql/utils.py +++ b/siuba/sql/utils.py @@ -49,8 +49,15 @@ def mock_sqlalchemy_engine(dialect): from sqlalchemy.engine import Engine from sqlalchemy.dialects import registry + from types import ModuleType dialect_cls = registry.load(dialect) + + # there is probably a better way to do this, but for some reason duckdb + # returns a module, rather than the dialect class itself. By convention, + # dialect modules expose a variable named dialect, so we grab that. + if isinstance(dialect_cls, ModuleType): + dialect_cls = dialect_cls.dialect return MockConnection(dialect_cls(), lambda *args, **kwargs: None) diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index 5749d573..67aac8bf 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -471,11 +471,16 @@ def _collect(__data, as_df = True): # psycopg2 completes about incomplete template. # see https://stackoverflow.com/a/47193568/1144523 - #query = __data.last_op - #compiled = query.compile( - # dialect = __data.source.dialect, - # compile_kwargs = {"literal_binds": True} - #) + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + if __data.source.engine.dialect.name == "duckdb": + query = __data.last_op + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_op if isinstance(__data.source, MockConnection): # a mock sqlalchemy is being used to show_query, and echo queries. 
@@ -487,9 +492,9 @@ def _collect(__data, as_df = True): if as_df: sql_db = _FixedSqlDatabase(conn) - return sql_db.read_sql(__data.last_op) + return sql_db.read_sql(compiled) - return conn.execute(__data.last_op) + return conn.execute(compiled) @select.register(LazyTbl) diff --git a/siuba/tests/conftest.py b/siuba/tests/conftest.py index 38aeaae8..1e83f628 100644 --- a/siuba/tests/conftest.py +++ b/siuba/tests/conftest.py @@ -17,6 +17,7 @@ def pytest_addoption(parser): pytest.param(lambda: SqlBackend("postgresql"), id = "postgresql", marks=pytest.mark.postgresql), pytest.param(lambda: SqlBackend("mysql"), id = "mysql", marks=pytest.mark.mysql), pytest.param(lambda: SqlBackend("sqlite"), id = "sqlite", marks=pytest.mark.sqlite), + pytest.param(lambda: SqlBackend("duckdb"), id = "duckdb", marks=pytest.mark.duckdb), pytest.param(lambda: BigqueryBackend("bigquery"), id = "bigquery", marks=pytest.mark.bigquery), pytest.param(lambda: CloudBackend("snowflake"), id = "snowflake", marks=pytest.mark.snowflake), pytest.param(lambda: PandasBackend("pandas"), id = "pandas", marks=pytest.mark.pandas) diff --git a/siuba/tests/helpers.py b/siuba/tests/helpers.py index 09c90b0e..018c7391 100644 --- a/siuba/tests/helpers.py +++ b/siuba/tests/helpers.py @@ -70,10 +70,22 @@ def data_frame(*args, _index = None, **kwargs): "password": "", "host": "", "options": "" - } + }, + "duckdb": { + "dialect": "duckdb", + "driver": "", + "dbname": ":memory:", + "port": "", + "user": "", + "password": "", + "host": "", + }, } class Backend: + # TODO: use abstract base class + kind = "pandas" + def __init__(self, name): self.name = name @@ -91,13 +103,18 @@ def load_df(self, df = None, **kwargs): def load_cached_df(self, df): return df + def matches_test_qualifier(self, s): + return self.name == s or self.kind == s + def __repr__(self): return "{0}({1})".format(self.__class__.__name__, repr(self.name)) class PandasBackend(Backend): - pass + kind = "pandas" class SqlBackend(Backend): + kind = "sql" + 
table_name_indx = 0 # if there is a :, sqlalchemy tries to parse the port number. @@ -105,6 +122,8 @@ class SqlBackend(Backend): # later on the port value passed in. sa_conn_fmt = "{dialect}://{user}:{password}@{host}{port}/{dbname}?{options}" + sa_conn_memory_fmt = "{dialect}:///{dbname}" + def __init__(self, name): from urllib.parse import quote_plus @@ -117,8 +136,13 @@ def __init__(self, name): if params["password"]: params["password"] = quote_plus(params["password"]) + if params["dbname"] == ":memory:": + sa_conn_uri = self.sa_conn_memory_fmt.format(**params) + else: + sa_conn_uri = self.sa_conn_fmt.format(**params) + self.name = name - self.engine = sqla.create_engine(self.sa_conn_fmt.format(**params)) + self.engine = sqla.create_engine(sa_conn_uri) self.cache = {} def dispose(self): @@ -265,7 +289,7 @@ def backend_notimpl(*names): def outer(f): @wraps(f) def wrapper(backend, *args, **kwargs): - if backend.name in names: + if any(map(backend.matches_test_qualifier, names)): with pytest.raises((NotImplementedError, FunctionLookupError)): f(backend, *args, **kwargs) pytest.xfail("Not implemented!") diff --git a/siuba/tests/test_sql_misc.py b/siuba/tests/test_sql_misc.py index 6748978e..8a89a52a 100644 --- a/siuba/tests/test_sql_misc.py +++ b/siuba/tests/test_sql_misc.py @@ -35,8 +35,14 @@ def test_raw_sql_mutate_grouped(backend, df): @backend_sql def test_raw_sql_mutate_refer_previous_raise_dberror(backend, skip_backend, df): # Note: unlikely will be able to support this case. Normally we analyze - # the expression to know whether we need to create a subquery. - with pytest.raises(sqlalchemy.exc.DatabaseError): + # the expression to know whether we need to create a subquery. + if backend.name == "duckdb": + # duckdb dialect re-raises the engine's exception, which is RuntimeError 
+ exc = RuntimeError + else: + exc = sqlalchemy.exc.DatabaseError + + with pytest.raises(exc): assert_equal_query( df, group_by("x") >> mutate(z1 = sql_raw("y + 1"), z2 = sql_raw("z1 + 1")), @@ -44,7 +50,7 @@ def test_raw_sql_mutate_refer_previous_raise_dberror(backend, skip_backend, df): ) -@pytest.mark.xfail_backend("postgresql", "mysql", "bigquery", "sqlite") +@pytest.mark.xfail_backend("postgresql", "mysql", "bigquery", "sqlite", "duckdb") @backend_sql def test_raw_sql_mutate_refer_previous_succeeds(backend, xfail_backend, df): assert_equal_query( diff --git a/siuba/tests/test_verb_count.py b/siuba/tests/test_verb_count.py index c2be4d2d..4e8e0518 100644 --- a/siuba/tests/test_verb_count.py +++ b/siuba/tests/test_verb_count.py @@ -49,7 +49,7 @@ def test_count_with_kwarg_expression(df): pd.DataFrame({"y": [0], "n": [4]}) ) -@backend_notimpl("sqlite", "postgresql", "mysql", "bigquery", "snowflake") # see (#104) +@backend_notimpl("sql") # see (#104) def test_count_wt(backend, df): assert_equal_query( df, @@ -65,7 +65,7 @@ def test_count_no_groups(df): pd.DataFrame({'n': [4]}) ) -@backend_notimpl("sqlite", "postgresql", "mysql", "bigquery", "snowflake") # see (#104) +@backend_notimpl("sql") # see (#104) def test_count_no_groups_wt(backend, df): assert_equal_query( df, From 15175250ba70e376c981764f638a5cfcc660cf41 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 2 May 2022 22:52:30 -0400 Subject: [PATCH 2/7] ci: install duckdb_engine from repo for now --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f8edacf4..8dfe1ba8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,7 +51,7 @@ jobs: python -m pip install -r requirements-test.txt # step to test duckdb # TODO: move these requirements into the test matrix - pip install duckdb_engine + pip install git+https://github.com/Mause/duckdb_engine.git python -m pip install . 
env: REQUIREMENTS: ${{ matrix.requirements }} From 6ea9e879d633efbadf225c0a21286c0e98146e22 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 30 May 2022 19:57:56 -0400 Subject: [PATCH 3/7] feat(sql): speed up duckdb query collection --- siuba/sql/dialects/base.py | 4 ++-- siuba/sql/dialects/duckdb.py | 3 --- siuba/sql/verbs.py | 32 +++++++++++++++++++++----------- siuba/tests/helpers.py | 10 ++++++++++ 4 files changed, 33 insertions(+), 16 deletions(-) diff --git a/siuba/sql/dialects/base.py b/siuba/sql/dialects/base.py index babf6d75..94d40f77 100644 --- a/siuba/sql/dialects/base.py +++ b/siuba/sql/dialects/base.py @@ -368,7 +368,7 @@ def req_bool(f): cummax = win_cumul("max"), cummin = win_cumul("min"), #cumprod = - cumsum = annotate(win_cumul("sum"), result_type = "float"), + cumsum = annotate(win_cumul("sum"), result_type = "variable"), diff = sql_func_diff, #is_monotonic = #is_monotonic_decreasing = @@ -397,7 +397,7 @@ def req_bool(f): #sem = #skew = #std = # TODO(pg) - sum = annotate(win_agg("sum"), result_type = "float"), + sum = annotate(win_agg("sum"), result_type = "variable"), #var = # TODO(pg) diff --git a/siuba/sql/dialects/duckdb.py b/siuba/sql/dialects/duckdb.py index c6c266fd..aa132b10 100644 --- a/siuba/sql/dialects/duckdb.py +++ b/siuba/sql/dialects/duckdb.py @@ -92,8 +92,6 @@ def f_quantile(codata, col, q, *args): extend_base( DuckdbColumn, - cumsum = win_cumul("sum"), - sum = win_agg("sum"), rank = sql_func_rank, #quantile = sql_quantile(is_analytic=True), ) @@ -103,7 +101,6 @@ def f_quantile(codata, col, q, *args): extend_base( DuckdbColumnAgg, - sum = sql_agg("sum"), quantile = sql_quantile(), ) diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index 67aac8bf..1b0d1c66 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -467,13 +467,18 @@ def _show_query(tbl, simplify = False): @collect.register(LazyTbl) def _collect(__data, as_df = True): # TODO: maybe remove as_df options, always return dataframe - # normally can just pass 
the sql objects to execute, but for some reason - # psycopg2 completes about incomplete template. - # see https://stackoverflow.com/a/47193568/1144523 - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + if __data.source.engine.dialect.name == "duckdb": + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 query = __data.last_op compiled = query.compile( dialect = __data.source.dialect, @@ -482,17 +487,22 @@ def _collect(__data, as_df = True): else: compiled = __data.last_op - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return + # execute query ---- with __data.source.connect() as conn: if as_df: sql_db = _FixedSqlDatabase(conn) - return sql_db.read_sql(compiled) + if __data.source.engine.dialect.name == "duckdb": + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # all other dialects are read through pandas read_sql + return sql_db.read_sql(compiled) return conn.execute(compiled) diff --git a/siuba/tests/helpers.py b/siuba/tests/helpers.py index 018c7391..5d9d3bda 100644 --- a/siuba/tests/helpers.py +++ b/siuba/tests/helpers.py @@ -226,6 +226,16 @@ def assert_equal_query(tbl, lazy_query, target, **kwargs): out = collect(lazy_query(tbl)) + if isinstance(tbl, LazyTbl) and tbl.source.dialect.name == "duckdb": + # TODO: find a nice way to remove duckdb specific code from here + # duckdb does not use pandas.DataFrame.to_sql method, which coerces + # everything to 64 bit. So we need to coerce any results it returns + # as 32 bit to 64 bit, to match to_sql. + int_cols = out.select_dtypes('int').columns + flt_cols = out.select_dtypes('float').columns + out[int_cols] = out[int_cols].astype('int64') + out[flt_cols] = out[flt_cols].astype('float64') + if isinstance(tbl, (pd.DataFrame, DataFrameGroupBy)): df_a = ungroup(out).reset_index(drop = True) df_b = ungroup(target).reset_index(drop = True) From ca79b3958802c05e4200c44710985f48e3cd1e36 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 30 May 2022 20:42:48 -0400 Subject: [PATCH 4/7] fix(sql): use sqlalchemy engine url to get dialect name --- siuba/sql/utils.py | 12 +++++++++--- siuba/sql/verbs.py | 8 +++++--- siuba/tests/helpers.py | 3 ++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/siuba/sql/utils.py b/siuba/sql/utils.py index 32829eea..371446ad 100644 --- a/siuba/sql/utils.py +++ b/siuba/sql/utils.py @@ -4,7 +4,6 @@ # once we drop sqlalchemy 1.2, can use create_mock_engine function from sqlalchemy.engine.mock import MockConnection except ImportError: - # monkey patch old sqlalchemy mock, so it can be a context handler from sqlalchemy.engine.strategies import MockEngineStrategy MockConnection = MockEngineStrategy.MockConnection @@ -47,7 +46,7 @@ def mock_sqlalchemy_engine(dialect): """ - from sqlalchemy.engine import 
Engine + from sqlalchemy.engine import Engine, URL from sqlalchemy.dialects import registry from types import ModuleType @@ -59,7 +58,9 @@ def mock_sqlalchemy_engine(dialect): if isinstance(dialect_cls, ModuleType): dialect_cls = dialect_cls.dialect - return MockConnection(dialect_cls(), lambda *args, **kwargs: None) + conn = MockConnection(dialect_cls(), lambda *args, **kwargs: None) + conn.url = URL.create(drivername=dialect) + return conn # Temporary fix for pandas bug (https://github.com/pandas-dev/pandas/issues/35484) @@ -70,6 +71,11 @@ def execute(self, *args, **kwargs): return self.connectable.execute(*args, **kwargs) +# Detect duckdb for temporary workarounds ------------------------------------- + +def _is_dialect_duckdb(engine): + return engine.url.get_backend_name() == "duckdb" + # Backwards compatibility for sqlalchemy 1.3 ---------------------------------- import re diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index 1b0d1c66..60af12c5 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -32,6 +32,7 @@ from .utils import ( get_dialect_translator, _FixedSqlDatabase, + _is_dialect_duckdb, _sql_select, _sql_column_collection, _sql_add_columns, @@ -281,7 +282,8 @@ def __init__( # connection and dialect specific functions self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - dialect = self.source.dialect.name + # get dialect name + dialect = self.source.url.get_backend_name() self.translator = get_dialect_translator(dialect) self.tbl = self._create_table(tbl, columns, self.source) @@ -476,7 +478,7 @@ def _collect(__data, as_df = True): # compile query ---- - if __data.source.engine.dialect.name == "duckdb": + if _is_dialect_duckdb(__data.source): # TODO: can be removed once next release of duckdb fixes: # https://github.com/duckdb/duckdb/issues/2972 query = __data.last_op @@ -493,7 +495,7 @@ def _collect(__data, as_df = True): if as_df: sql_db = _FixedSqlDatabase(conn) - if __data.source.engine.dialect.name == 
"duckdb": + if _is_dialect_duckdb(__data.source): # TODO: pandas read_sql is very slow with duckdb. # see https://github.com/pandas-dev/pandas/issues/45678 # going to handle here for now. address once LazyTbl gets diff --git a/siuba/tests/helpers.py b/siuba/tests/helpers.py index 5d9d3bda..792b5783 100644 --- a/siuba/tests/helpers.py +++ b/siuba/tests/helpers.py @@ -2,6 +2,7 @@ import uuid from siuba.sql import LazyTbl +from siuba.sql.utils import _is_dialect_duckdb from siuba.dply.verbs import ungroup, collect from siuba.siu import FunctionLookupError from pandas.testing import assert_frame_equal @@ -226,7 +227,7 @@ def assert_equal_query(tbl, lazy_query, target, **kwargs): out = collect(lazy_query(tbl)) - if isinstance(tbl, LazyTbl) and tbl.source.dialect.name == "duckdb": + if isinstance(tbl, LazyTbl) and _is_dialect_duckdb(tbl.source): # TODO: find a nice way to remove duckdb specific code from here # duckdb does not use pandas.DataFrame.to_sql method, which coerces # everything to 64 bit. So we need to coerce any results it returns From 507bb98ee1f4538469a516c61001a609ce00fcb8 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 30 May 2022 21:10:53 -0400 Subject: [PATCH 5/7] tests: fix duckdb type conversions in tests --- siuba/tests/helpers.py | 4 ++-- siuba/tests/test_sql_verbs.py | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/siuba/tests/helpers.py b/siuba/tests/helpers.py index 792b5783..b55b6033 100644 --- a/siuba/tests/helpers.py +++ b/siuba/tests/helpers.py @@ -232,8 +232,8 @@ def assert_equal_query(tbl, lazy_query, target, **kwargs): # duckdb does not use pandas.DataFrame.to_sql method, which coerces # everything to 64 bit. So we need to coerce any results it returns # as 32 bit to 64 bit, to match to_sql. 
- int_cols = out.select_dtypes('int').columns - flt_cols = out.select_dtypes('float').columns + int_cols = out.select_dtypes('int32').columns + flt_cols = out.select_dtypes('float32').columns out[int_cols] = out[int_cols].astype('int64') out[flt_cols] = out[flt_cols].astype('float64') diff --git a/siuba/tests/test_sql_verbs.py b/siuba/tests/test_sql_verbs.py index 0b24d9ee..4870ef16 100644 --- a/siuba/tests/test_sql_verbs.py +++ b/siuba/tests/test_sql_verbs.py @@ -32,15 +32,16 @@ def db(): metadata.create_all(engine) - conn = engine.connect() + with engine.connect() as conn: - ins = users.insert().values(name='jack', fullname='Jack Jones') - result = conn.execute(ins) + ins = users.insert().values(name='jack', fullname='Jack Jones') + result = conn.execute(ins) - ins = users.insert() - conn.execute(ins, id=2, name='wendy', fullname='Wendy Williams') - yield conn + ins = users.insert() + conn.execute(ins, id=2, name='wendy', fullname='Wendy Williams') + + yield engine # LazyTbl --------------------------------------------------------------------- From 5aaa935c3f3cb0aa1017860d7691455536a984dd Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 30 May 2022 21:17:56 -0400 Subject: [PATCH 6/7] fix: support sqlalchemy 1.3.18 compat --- siuba/sql/utils.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/siuba/sql/utils.py b/siuba/sql/utils.py index 371446ad..ed000103 100644 --- a/siuba/sql/utils.py +++ b/siuba/sql/utils.py @@ -46,10 +46,14 @@ def mock_sqlalchemy_engine(dialect): """ - from sqlalchemy.engine import Engine, URL + from sqlalchemy.engine import Engine from sqlalchemy.dialects import registry from types import ModuleType + # TODO: can be removed once v1.3.18 support dropped + from sqlalchemy.engine.url import URL + + dialect_cls = registry.load(dialect) # there is probably a better way to do this, but for some reason duckdb @@ -59,7 +63,15 @@ def mock_sqlalchemy_engine(dialect): dialect_cls = dialect_cls.dialect conn = 
MockConnection(dialect_cls(), lambda *args, **kwargs: None) - conn.url = URL.create(drivername=dialect) + + # set a url on it, so that LazyTbl can read the backend name. + if is_sqla_12() or is_sqla_13(): + url = URL(drivername=dialect) + else: + url = URL.create(drivername=dialect) + + conn.url = url + return conn From 5c3edfadd2ac0c8cacc9323efc6960de38b03027 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 30 May 2022 22:40:21 -0400 Subject: [PATCH 7/7] ci: do not install duckdb_engine from github --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8dfe1ba8..f8edacf4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,7 +51,7 @@ jobs: python -m pip install -r requirements-test.txt # step to test duckdb # TODO: move these requirements into the test matrix - pip install git+https://github.com/Mause/duckdb_engine.git + pip install duckdb_engine python -m pip install . env: REQUIREMENTS: ${{ matrix.requirements }}