Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: duckdb support #422

Merged
merged 7 commits into from
May 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ jobs:
python -m pip install --upgrade pip
python -m pip install $REQUIREMENTS
python -m pip install -r requirements-test.txt
# step to test duckdb
# TODO: move these requirements into the test matrix
pip install duckdb_engine
python -m pip install .
env:
REQUIREMENTS: ${{ matrix.requirements }}
Expand Down
2 changes: 1 addition & 1 deletion siuba/ops/support/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from siuba.siu import FunctionLookupBound
from siuba.sql.utils import get_dialect_translator

SQL_BACKENDS = ["postgresql", "redshift", "sqlite", "mysql", "bigquery", "snowflake"]
SQL_BACKENDS = ["postgresql", "redshift", "sqlite", "mysql", "bigquery", "snowflake", "duckdb"]
ALL_BACKENDS = SQL_BACKENDS + ["pandas"]

methods = pd.DataFrame(
Expand Down
2 changes: 1 addition & 1 deletion siuba/sql/dialects/_dt_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def date_trunc(_, col, period):

@symbolic_dispatch(cls = SqlColumn)
def sql_func_last_day_in_period(codata, col, period):
    # Truncate to the start of the period, then shift forward one full period
    # and back one day to land on the period's final day.
    interval = "INTERVAL '1 %s' - INTERVAL '1 day'" % period
    return date_trunc(codata, col, period) + sql.text(interval)


# TODO: RENAME TO GEN
Expand Down
4 changes: 2 additions & 2 deletions siuba/sql/dialects/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def req_bool(f):
cummax = win_cumul("max"),
cummin = win_cumul("min"),
#cumprod =
cumsum = annotate(win_cumul("sum"), result_type = "float"),
cumsum = annotate(win_cumul("sum"), result_type = "variable"),
diff = sql_func_diff,
#is_monotonic =
#is_monotonic_decreasing =
Expand Down Expand Up @@ -397,7 +397,7 @@ def req_bool(f):
#sem =
#skew =
#std = # TODO(pg)
sum = annotate(win_agg("sum"), result_type = "float"),
sum = annotate(win_agg("sum"), result_type = "variable"),
#var = # TODO(pg)


Expand Down
108 changes: 108 additions & 0 deletions siuba/sql/dialects/duckdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from sqlalchemy.sql import func as fn
from sqlalchemy import sql

from ..translate import (
# data types
SqlColumn, SqlColumnAgg,
AggOver,
# transformations
wrap_annotate,
sql_agg,
win_agg,
win_cumul,
sql_not_impl,
# wiring up translator
extend_base,
SqlTranslator
)

from .postgresql import (
PostgresqlColumn,
PostgresqlColumnAgg,
)

from .base import sql_func_rank


# Data ========================================================================

# Duckdb reuses the postgresql translations wholesale; only the deviations
# are re-registered below.
class DuckdbColumn(PostgresqlColumn): pass
class DuckdbColumnAgg(PostgresqlColumnAgg, DuckdbColumn): pass


# Annotations =================================================================

def returns_int(func_names):
    """Re-register each named operation on DuckdbColumn with an int result type.

    Each op keeps the generic SqlColumn implementation; only its result_type
    annotation is replaced.
    """
    # TODO: MC-NOTE - shift all translations to directly register
    # TODO: MC-NOTE - make an AliasAnnotated class or something, that signals
    # it is using another method, but w/ an updated annotation.
    from siuba.ops import ALL_OPS

    for op_name in func_names:
        generic_op = ALL_OPS[op_name]
        annotated = wrap_annotate(generic_op.dispatch(SqlColumn), result_type="int")
        generic_op.register(DuckdbColumn, annotated)


# Translations ================================================================


def sql_quantile(is_analytic=False):
    """Return a translation for quantile, using percentile_cont.

    Parameters
    ----------
    is_analytic :
        When True, wrap the ordered-set aggregate in an OVER clause so it can
        be used as a window expression.
    """
    # Ordered and theoretical set aggregates
    sa_func = getattr(sql.func, "percentile_cont")

    def f_quantile(codata, col, q, *args):
        if args:
            raise NotImplementedError("Quantile only supports the q argument.")
        if not isinstance(q, (int, float)):
            raise TypeError("q argument must be int or float, but received: %s" % type(q))

        # as far as I can tell, there's no easy way to tell sqlalchemy to render
        # the exact text a dialect would render for a literal (except maybe using
        # literal_column), so use the classic sql.text.
        q_text = sql.text(str(q))

        if is_analytic:
            # BUG FIX: q_text is already a TextClause — wrapping it in
            # sql.text() a second time fails, so pass it through directly
            # (mirrors the non-analytic branch below).
            return AggOver(sa_func(q_text).within_group(col))

        return sa_func(q_text).within_group(col)

    return f_quantile


# scalar ----

# Deviations from the inherited postgres scalar translations.
extend_base(
    DuckdbColumn,
    **{
        # map str.contains onto duckdb's regexp_matches function
        "str.contains": lambda _, col, re: fn.regexp_matches(col, re),
        # no translation wired up, so mark explicitly unsupported
        "str.title": sql_not_impl(),
    }
)

# Re-annotate these dt accessors with an int result type (the inherited
# postgres annotations presumably differ — confirm against backend output).
returns_int([
    "dt.day", "dt.dayofyear", "dt.days_in_month",
    "dt.daysinmonth", "dt.hour", "dt.minute", "dt.month",
    "dt.quarter", "dt.second", "dt.week",
    "dt.weekofyear", "dt.year"
])

# window ----

extend_base(
    DuckdbColumn,
    rank = sql_func_rank,
    #quantile = sql_quantile(is_analytic=True),
)


# aggregate ----

extend_base(
    DuckdbColumnAgg,
    quantile = sql_quantile(),
)


# Translator consumed by the sql verbs machinery for the duckdb dialect.
translator = SqlTranslator.from_mappings(DuckdbColumn, DuckdbColumnAgg)
2 changes: 1 addition & 1 deletion siuba/sql/dialects/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class PostgresqlColumnAgg(SqlColumnAgg, PostgresqlColumn): pass

@annotate(return_type="float")
def sql_is_quarter_end(_, col):
    # Compute the quarter's final day (start of quarter + 3 months - 1 day),
    # then compare against col truncated to day precision.
    quarter_start = fn.date_trunc("quarter", col)
    last_day = quarter_start + sql.text("INTERVAL '3 month' - INTERVAL '1 day'")
    return fn.date_trunc("day", col) == last_day


Expand Down
3 changes: 2 additions & 1 deletion siuba/sql/dply/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..dialects.sqlite import SqliteColumn
from ..dialects.mysql import MysqlColumn
from ..dialects.bigquery import BigqueryColumn
from ..dialects.duckdb import DuckdbColumn

from siuba.dply.vector import (
#cumall, cumany, cummean,
Expand Down Expand Up @@ -99,7 +100,7 @@ def f(_, col, na_option = None) -> RankOver:
# bigquery and duckdb sort nulls last for ranking functions, so register
# variants with nulls_last enabled.
dense_rank  .register(BigqueryColumn, _sql_rank("dense_rank", nulls_last = True))
percent_rank.register(BigqueryColumn, _sql_rank("percent_rank", nulls_last = True))


dense_rank  .register(DuckdbColumn, _sql_rank("dense_rank", nulls_last = True))

# row_number ------------------------------------------------------------------

Expand Down
29 changes: 27 additions & 2 deletions siuba/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# once we drop sqlalchemy 1.2, can use create_mock_engine function
from sqlalchemy.engine.mock import MockConnection
except ImportError:
# monkey patch old sqlalchemy mock, so it can be a context handler
from sqlalchemy.engine.strategies import MockEngineStrategy
MockConnection = MockEngineStrategy.MockConnection

Expand Down Expand Up @@ -49,10 +48,31 @@ def mock_sqlalchemy_engine(dialect):

from sqlalchemy.engine import Engine
from sqlalchemy.dialects import registry
from types import ModuleType

# TODO: can be removed once v1.3.18 support dropped
from sqlalchemy.engine.url import URL


dialect_cls = registry.load(dialect)

# there is probably a better way to do this, but for some reason duckdb
# returns a module, rather than the dialect class itself. By convention,
# dialect modules expose a variable named dialect, so we grab that.
if isinstance(dialect_cls, ModuleType):
dialect_cls = dialect_cls.dialect

return MockConnection(dialect_cls(), lambda *args, **kwargs: None)
conn = MockConnection(dialect_cls(), lambda *args, **kwargs: None)

# set a url on it, so that LazyTbl can read the backend name.
if is_sqla_12() or is_sqla_13():
url = URL(drivername=dialect)
else:
url = URL.create(drivername=dialect)

conn.url = url

return conn


# Temporary fix for pandas bug (https://github.com/pandas-dev/pandas/issues/35484)
Expand All @@ -63,6 +83,11 @@ def execute(self, *args, **kwargs):
return self.connectable.execute(*args, **kwargs)


# Detect duckdb for temporary workarounds -------------------------------------

def _is_dialect_duckdb(engine):
return engine.url.get_backend_name() == "duckdb"

# Backwards compatibility for sqlalchemy 1.3 ----------------------------------

import re
Expand Down
43 changes: 30 additions & 13 deletions siuba/sql/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .utils import (
get_dialect_translator,
_FixedSqlDatabase,
_is_dialect_duckdb,
_sql_select,
_sql_column_collection,
_sql_add_columns,
Expand Down Expand Up @@ -281,7 +282,8 @@ def __init__(
# connection and dialect specific functions
self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source

dialect = self.source.dialect.name
# get dialect name
dialect = self.source.url.get_backend_name()
self.translator = get_dialect_translator(dialect)

self.tbl = self._create_table(tbl, columns, self.source)
Expand Down Expand Up @@ -467,29 +469,44 @@ def _show_query(tbl, simplify = False):
@collect.register(LazyTbl)
def _collect(__data, as_df = True):
    """Execute the lazy table's query.

    Parameters
    ----------
    __data :
        The LazyTbl whose last_op query should be executed.
    as_df :
        When True (default), return a pandas DataFrame; otherwise return the
        raw sqlalchemy result. Returns None for mock engines.
    """
    # TODO: maybe remove as_df options, always return dataframe

    if isinstance(__data.source, MockConnection):
        # a mock sqlalchemy engine is only used to show_query and echo
        # queries. it doesn't return a result object or have a context
        # handler, so we need to bail out early
        return

    # compile query ----

    if _is_dialect_duckdb(__data.source):
        # TODO: can be removed once next release of duckdb fixes:
        # https://github.com/duckdb/duckdb/issues/2972
        # work around by rendering literal values directly into the SQL.
        query = __data.last_op
        compiled = query.compile(
            dialect = __data.source.dialect,
            compile_kwargs = {"literal_binds": True}
        )
    else:
        compiled = __data.last_op

    # execute query ----

    with __data.source.connect() as conn:
        if as_df:
            sql_db = _FixedSqlDatabase(conn)

            if _is_dialect_duckdb(__data.source):
                # TODO: pandas read_sql is very slow with duckdb.
                # see https://github.com/pandas-dev/pandas/issues/45678
                # going to handle here for now. address once LazyTbl gets
                # subclassed per backend.
                duckdb_con = conn.connection.c
                return duckdb_con.query(str(compiled)).to_df()

            return sql_db.read_sql(compiled)

        return conn.execute(compiled)


@select.register(LazyTbl)
Expand Down
1 change: 1 addition & 0 deletions siuba/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def pytest_addoption(parser):
pytest.param(lambda: SqlBackend("postgresql"), id = "postgresql", marks=pytest.mark.postgresql),
pytest.param(lambda: SqlBackend("mysql"), id = "mysql", marks=pytest.mark.mysql),
pytest.param(lambda: SqlBackend("sqlite"), id = "sqlite", marks=pytest.mark.sqlite),
pytest.param(lambda: SqlBackend("duckdb"), id = "duckdb", marks=pytest.mark.duckdb),
pytest.param(lambda: BigqueryBackend("bigquery"), id = "bigquery", marks=pytest.mark.bigquery),
pytest.param(lambda: CloudBackend("snowflake"), id = "snowflake", marks=pytest.mark.snowflake),
pytest.param(lambda: PandasBackend("pandas"), id = "pandas", marks=pytest.mark.pandas)
Expand Down
Loading