From b23c0f5b4e22eb47a6876287158f7f5e99f19ff1 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sat, 8 Feb 2025 09:20:55 -0500 Subject: [PATCH 01/45] feat: add ibis-yaml compiler - passes tpch-01-22 - unbound expr only - support builtin ScalarUDF --- pyproject.toml | 3 + python/letsql/ibis_yaml/__init__.py | 0 python/letsql/ibis_yaml/compiler.py | 15 + python/letsql/ibis_yaml/tests/__init__.py | 0 python/letsql/ibis_yaml/tests/conftest.py | 787 +++++++++++ .../letsql/ibis_yaml/tests/test_arithmetic.py | 96 ++ python/letsql/ibis_yaml/tests/test_basic.py | 156 +++ .../letsql/ibis_yaml/tests/test_join_chain.py | 101 ++ .../tests/test_operations_boolean.py | 104 ++ .../ibis_yaml/tests/test_operations_cast.py | 53 + .../tests/test_operations_datetime.py | 105 ++ .../letsql/ibis_yaml/tests/test_relations.py | 86 ++ .../letsql/ibis_yaml/tests/test_selection.py | 13 + .../letsql/ibis_yaml/tests/test_string_ops.py | 48 + .../letsql/ibis_yaml/tests/test_subquery.py | 42 + python/letsql/ibis_yaml/tests/test_tpch.py | 50 + python/letsql/ibis_yaml/tests/test_udf.py | 59 + .../ibis_yaml/tests/test_window_functions.py | 57 + python/letsql/ibis_yaml/translate.py | 1157 +++++++++++++++++ python/letsql/ibis_yaml/utils.py | 127 ++ requirements-dev.txt | 2 +- uv.lock | 22 + 22 files changed, 3082 insertions(+), 1 deletion(-) create mode 100644 python/letsql/ibis_yaml/__init__.py create mode 100644 python/letsql/ibis_yaml/compiler.py create mode 100644 python/letsql/ibis_yaml/tests/__init__.py create mode 100644 python/letsql/ibis_yaml/tests/conftest.py create mode 100644 python/letsql/ibis_yaml/tests/test_arithmetic.py create mode 100644 python/letsql/ibis_yaml/tests/test_basic.py create mode 100644 python/letsql/ibis_yaml/tests/test_join_chain.py create mode 100644 python/letsql/ibis_yaml/tests/test_operations_boolean.py create mode 100644 python/letsql/ibis_yaml/tests/test_operations_cast.py create mode 100644 python/letsql/ibis_yaml/tests/test_operations_datetime.py create mode 
100644 python/letsql/ibis_yaml/tests/test_relations.py create mode 100644 python/letsql/ibis_yaml/tests/test_selection.py create mode 100644 python/letsql/ibis_yaml/tests/test_string_ops.py create mode 100644 python/letsql/ibis_yaml/tests/test_subquery.py create mode 100644 python/letsql/ibis_yaml/tests/test_tpch.py create mode 100644 python/letsql/ibis_yaml/tests/test_udf.py create mode 100644 python/letsql/ibis_yaml/tests/test_window_functions.py create mode 100644 python/letsql/ibis_yaml/translate.py create mode 100644 python/letsql/ibis_yaml/utils.py diff --git a/pyproject.toml b/pyproject.toml index efc20000..bfb9f909 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,9 @@ dependencies = [ "sqlglot==25.20.2", "toolz>=0.11", "typing-extensions>=4.3.0", + "hypothesis>=6.124.9", + "pyyaml>=6.0.2", + "cloudpickle>=3.1.1", ] requires-python = ">=3.10" authors = [ diff --git a/python/letsql/ibis_yaml/__init__.py b/python/letsql/ibis_yaml/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/letsql/ibis_yaml/compiler.py b/python/letsql/ibis_yaml/compiler.py new file mode 100644 index 00000000..6610472b --- /dev/null +++ b/python/letsql/ibis_yaml/compiler.py @@ -0,0 +1,15 @@ +from letsql.ibis_yaml.translate import translate_from_yaml, translate_to_yaml + + +class IbisYamlCompiler: + def __init__(self): + pass + + def compile_to_yaml(self, expr): + self.current_relation = None + unbound_expr = expr.unbind() + return translate_to_yaml(unbound_expr.op(), self) + + def compile_from_yaml(self, yaml_dict): + self.current_relation = None + return translate_from_yaml(yaml_dict, self) diff --git a/python/letsql/ibis_yaml/tests/__init__.py b/python/letsql/ibis_yaml/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/letsql/ibis_yaml/tests/conftest.py b/python/letsql/ibis_yaml/tests/conftest.py new file mode 100644 index 00000000..082d411e --- /dev/null +++ b/python/letsql/ibis_yaml/tests/conftest.py @@ -0,0 
+1,787 @@ +from datetime import date + +import ibis +import ibis.expr.datatypes as dt +import pytest + + +# Fixtures from: https://github.com/ibis-project/ibis-substrait/blob/main/ibis_substrait/tests/compiler/test_tpch.py + + +@pytest.fixture +def t(): + return ibis.table( + dict( + a="int64", + b="string", + c="float64", + d="timestamp", + e="date", + ), + name="test_table", + ) + + +@pytest.fixture +def lineitem(): + return ibis.table( + [ + ("l_orderkey", dt.int64), + ("l_partkey", dt.int64), + ("l_suppkey", dt.int64), + ("l_linenumber", dt.int64), + ("l_quantity", dt.Decimal(15, 2)), + ("l_extendedprice", dt.Decimal(15, 2)), + ("l_discount", dt.Decimal(15, 2)), + ("l_tax", dt.Decimal(15, 2)), + ("l_returnflag", dt.string), + ("l_linestatus", dt.string), + ("l_shipdate", dt.date), + ("l_commitdate", dt.date), + ("l_receiptdate", dt.date), + ("l_shipinstruct", dt.string), + ("l_shipmode", dt.string), + ("l_comment", dt.string), + ], + name="lineitem", + ) + + +@pytest.fixture +def orders(): + return ibis.table( + [ + ("o_orderkey", dt.int32(nullable=False)), + ("o_custkey", dt.int32(nullable=False)), + ("o_orderstatus", dt.string(nullable=False)), + ("o_totalprice", dt.Decimal(precision=15, scale=2, nullable=False)), + ("o_orderdate", dt.date(nullable=False)), + ("o_orderpriority", dt.string(nullable=False)), + ("o_clerk", dt.string(nullable=False)), + ("o_shippriority", dt.int32(nullable=False)), + ("o_comment", dt.string(nullable=False)), + ("o_year", dt.date), + ], + name="orders", + ) + + +@pytest.fixture +def partsupp(): + return ibis.table( + [ + ("ps_partkey", dt.int32(nullable=False)), + ("ps_suppkey", dt.int32(nullable=False)), + ("ps_availqty", dt.int32(nullable=False)), + ("ps_supplycost", dt.Decimal(precision=15, scale=2, nullable=False)), + ("ps_comment", dt.string(nullable=False)), + ], + name="partsupp", + ) + + +@pytest.fixture +def part(): + return ibis.table( + [ + ("p_partkey", dt.int32(nullable=False)), + ("p_name", 
dt.string(nullable=False)), + ("p_mfgr", dt.string(nullable=False)), + ("p_brand", dt.string(nullable=False)), + ("p_type", dt.string(nullable=False)), + ("p_size", dt.int32(nullable=False)), + ("p_container", dt.string(nullable=False)), + ("p_retailprice", dt.Decimal(precision=15, scale=2, nullable=False)), + ("p_comment", dt.string(nullable=False)), + ], + name="part", + ) + + +@pytest.fixture +def customer(): + return ibis.table( + [ + ("c_custkey", dt.int32(nullable=False)), + ("c_name", dt.string(nullable=False)), + ("c_address", dt.string(nullable=False)), + ("c_nationkey", dt.int32(nullable=False)), + ("c_phone", dt.string(nullable=False)), + ("c_acctbal", dt.Decimal(precision=15, scale=2, nullable=False)), + ("c_mktsegment", dt.string(nullable=False)), + ("c_comment", dt.string(nullable=False)), + ], + name="customer", + ) + + +@pytest.fixture +def supplier(): + return ibis.table( + [ + ("s_suppkey", dt.int32(nullable=False)), + ("s_name", dt.string(nullable=False)), + ("s_address", dt.string(nullable=False)), + ("s_nationkey", dt.int32(nullable=False)), + ("s_phone", dt.string(nullable=False)), + ("s_acctbal", dt.Decimal(precision=15, scale=2, nullable=False)), + ("s_comment", dt.string(nullable=False)), + ], + name="supplier", + ) + + +@pytest.fixture +def nation(): + return ibis.table( + [ + ("n_nationkey", dt.int32(nullable=False)), + ("n_name", dt.string(nullable=False)), + ("n_regionkey", dt.int32(nullable=False)), + ("n_comment", dt.string(nullable=False)), + ("n_suppkey", dt.int32(nullable=False)), + ], + name="nation", + ) + + +@pytest.fixture +def region(): + return ibis.table( + [ + ("r_regionkey", dt.int32(nullable=False)), + ("r_name", dt.string(nullable=False)), + ("r_comment", dt.string(nullable=False)), + ], + name="region", + ) + + +@pytest.fixture +def tpc_h01(lineitem): + return ( + lineitem.filter(lambda t: t.l_shipdate <= date(year=1998, month=9, day=2)) + .group_by(["l_returnflag", "l_linestatus"]) + .aggregate( + sum_qty=lambda t: 
t.l_quantity.sum(), + sum_base_price=lambda t: t.l_extendedprice.sum(), + sum_disc_price=lambda t: (t.l_extendedprice * (1 - t.l_discount)).sum(), + sum_charge=lambda t: ( + t.l_extendedprice * (1 - t.l_discount) * (1 + t.l_tax) + ).sum(), + avg_qty=lambda t: t.l_quantity.mean(), + avg_price=lambda t: t.l_extendedprice.mean(), + avg_disc=lambda t: t.l_discount.mean(), + count_order=lambda t: t.count(), + ) + .order_by(["l_returnflag", "l_linestatus"]) + ) + + +@pytest.fixture +def tpc_h02( + part, supplier, partsupp, nation, region, REGION="EUROPE", SIZE=25, TYPE="BRASS" +): + "Minimum Cost Supplier Query (Q2)" + + expr = ( + part.join(partsupp, part.p_partkey == partsupp.ps_partkey) + .join(supplier, supplier.s_suppkey == partsupp.ps_suppkey) + .join(nation, supplier.s_nationkey == nation.n_nationkey) + .join(region, nation.n_regionkey == region.r_regionkey) + ) + + subexpr = ( + partsupp.join(supplier, supplier.s_suppkey == partsupp.ps_suppkey) + .join(nation, supplier.s_nationkey == nation.n_nationkey) + .join(region, nation.n_regionkey == region.r_regionkey) + ) + + subexpr = subexpr.filter( + [(subexpr.r_name == REGION) & (expr.p_partkey == subexpr.ps_partkey)] + ) + + filters = [ + expr.p_size == SIZE, + expr.p_type.like("%" + TYPE), + expr.r_name == REGION, + expr.ps_supplycost == subexpr.ps_supplycost.min(), + ] + q = expr.filter(filters) + + q = q.select( + [ + q.s_acctbal, + q.s_name, + q.n_name, + q.p_partkey, + q.p_mfgr, + q.s_address, + q.s_phone, + q.s_comment, + ] + ) + + return q.order_by( + [ + ibis.desc(q.s_acctbal), + q.n_name, + q.s_name, + q.p_partkey, + ] + ).limit(100) + + +@pytest.fixture +def tpc_h03(customer, orders, lineitem): + DATE = "1995-03-15" + q = customer.join(orders, customer.c_custkey == orders.o_custkey) + q = q.join(lineitem, lineitem.l_orderkey == orders.o_orderkey) + q = q.filter( + [q.c_mktsegment == "BUILDING", q.o_orderdate < DATE, q.l_shipdate > DATE] + ) + qg = q.group_by([q.l_orderkey, q.o_orderdate, q.o_shippriority]) 
+ q = qg.aggregate(revenue=(q.l_extendedprice * (1 - q.l_discount)).sum()) + q = q.order_by([ibis.desc(q.revenue), q.o_orderdate]) + q = q.limit(10) + return q + + +@pytest.fixture +def tpc_h04(orders, lineitem): + from ibis import _ + from ibis.expr.operations import ExistsSubquery + + lineitem_filtered = lineitem.filter( + [ + lineitem.l_orderkey == orders.o_orderkey, + lineitem.l_commitdate < lineitem.l_receiptdate, + ] + ) + cond_exists = ExistsSubquery(lineitem_filtered).to_expr() + + q = orders.filter( + [ + cond_exists, + orders.o_orderdate >= "1993-07-01", + orders.o_orderdate < "1993-10-01", + ] + ) + q = q.group_by([_.o_orderpriority]) + q = q.aggregate(order_count=_.count()) + q = q.order_by([_.o_orderpriority]) + return q + + +@pytest.fixture +def tpc_h05(customer, orders, lineitem, supplier, nation, region): + q = customer + q = q.join(orders, customer.c_custkey == orders.o_custkey) + q = q.join(lineitem, lineitem.l_orderkey == orders.o_orderkey) + q = q.join(supplier, lineitem.l_suppkey == supplier.s_suppkey) + q = q.join( + nation, + (customer.c_nationkey == supplier.s_nationkey) + & (supplier.s_nationkey == nation.n_nationkey), + ) + q = q.join(region, nation.n_regionkey == region.r_regionkey) + q = q.filter( + [ + q.r_name == "ASIA", + q.o_orderdate >= "1994-01-01", + q.o_orderdate < "1995-01-01", + ] + ) + revexpr = q.l_extendedprice * (1 - q.l_discount) + gq = q.group_by([q.n_name]) + q = gq.aggregate(revenue=revexpr.sum()) + q = q.order_by([ibis.desc(q.revenue)]) + return q + + +@pytest.fixture +def tpc_h06(lineitem): + q = lineitem + discount_min = round(0.06 - 0.01, 2) + discount_max = round(0.06 + 0.01, 2) + q = q.filter( + [ + q.l_shipdate >= "1994-01-01", + q.l_shipdate < "1995-01-01", + q.l_discount.between(discount_min, discount_max), + q.l_quantity < 24, + ] + ) + q = q.aggregate(revenue=(q.l_extendedprice * q.l_discount).sum()) + return q + + +@pytest.fixture +def tpc_h07(supplier, lineitem, orders, customer, nation): + q = supplier + q 
= q.join(lineitem, supplier.s_suppkey == lineitem.l_suppkey) + q = q.join(orders, orders.o_orderkey == lineitem.l_orderkey) + q = q.join(customer, customer.c_custkey == orders.o_custkey) + n1 = nation + n2 = nation.view() + q = q.join(n1, supplier.s_nationkey == n1.n_nationkey) + q = q.join(n2, customer.c_nationkey == n2.n_nationkey) + # q = q[ + # n1.n_name.name("supp_nation"), + # n2.n_name.name("cust_nation"), + # lineitem.l_shipdate, + # lineitem.l_extendedprice, + # lineitem.l_discount, + # lineitem.l_shipdate.year().cast("string").name("l_year"), + # (lineitem.l_extendedprice * (1 - lineitem.l_discount)).name("volume"), + # ] + + q = q.select( + { + "supp_nation": n1.n_name, + "cust_nation": n2.n_name, + "l_shipdate": lineitem.l_shipdate, + "l_extendedprice": lineitem.l_extendedprice, + "l_discount": lineitem.l_discount, + "l_year": lineitem.l_shipdate.year().cast("string"), + "volume": lineitem.l_extendedprice * (1 - lineitem.l_discount), + } + ) + + q = q.filter( + [ + ((q.cust_nation == "FRANCE") & (q.supp_nation == "GERMANY")) + | ((q.cust_nation == "GERMANY") & (q.supp_nation == "FRANCE")), + q.l_shipdate.between("1995-01-01", "1996-12-31"), + ] + ) + gq = q.group_by(["supp_nation", "cust_nation", "l_year"]) + q = gq.aggregate(revenue=q.volume.sum()) + q = q.order_by(["supp_nation", "cust_nation", "l_year"]) + return q + + +@pytest.fixture +def tpc_h08(part, supplier, lineitem, orders, customer, region, nation): + n1 = nation + n2 = n1.view() + q = part + q = q.join(lineitem, part.p_partkey == lineitem.l_partkey) + q = q.join(supplier, supplier.s_suppkey == lineitem.l_suppkey) + q = q.join(orders, lineitem.l_orderkey == orders.o_orderkey) + q = q.join(customer, orders.o_custkey == customer.c_custkey) + q = q.join(n1, customer.c_nationkey == n1.n_nationkey) + q = q.join(region, n1.n_regionkey == region.r_regionkey) + q = q.join(n2, supplier.s_suppkey == n2.n_suppkey) + + q = q.select( + [ + orders.o_orderdate.year().cast("string").name("o_year"), + 
(lineitem.l_extendedprice * (1 - lineitem.l_discount)).name("volume"), + n2.n_name.name("nation"), + region.r_name, + orders.o_orderdate, + part.p_type, + ] + ) + q = q.filter( + [ + q.r_name == "AMERICA", + q.o_orderdate.between("1995-01-01", "1996-12-31"), + q.p_type == "ECONOMY ANODIZED STEEL", + ] + ) + q = q.mutate( + nation_volume=ibis.case().when(q.nation == "BRAZIL", q.volume).else_(0).end() + ) + gq = q.group_by([q.o_year]) + q = gq.aggregate(nation_volume_sum=q.nation_volume.sum(), volume_sum=q.volume.sum()) + q = q.mutate(mkt_share=q.nation_volume_sum / q.volume_sum) + q = q.drop("nation_volume_sum", "volume_sum") + q = q.order_by([q.o_year]) + return q + + +@pytest.fixture +def tpc_h09(part, supplier, lineitem, partsupp, orders, nation): + q = lineitem + q = q.join(supplier, supplier.s_suppkey == lineitem.l_suppkey) + q = q.join( + partsupp, + (partsupp.ps_suppkey == lineitem.l_suppkey) + & (partsupp.ps_partkey == lineitem.l_partkey), + ) + q = q.join(part, part.p_partkey == lineitem.l_partkey) + q = q.join(orders, orders.o_orderkey == lineitem.l_orderkey) + q = q.join(nation, supplier.s_nationkey == nation.n_nationkey) + q = q.select( + { + "amount": q.l_extendedprice * (1 - q.l_discount) + - q.ps_supplycost * q.l_quantity, + "o_year": q.o_orderdate.year().cast("string"), + "nation": q.n_name, + "p_name": q.p_name, + } + ) + # q = q[ + # (q.l_extendedprice * (1 - q.l_discount) - q.ps_supplycost * q.l_quantity).name( + # "amount" + # ), + # q.o_orderdate.year().cast("string").name("o_year"), + # q.n_name.name("nation"), + # q.p_name, + # ] + q = q.filter([q.p_name.like("%GREEN%")]) + gq = q.group_by([q.nation, q.o_year]) + q = gq.aggregate(sum_profit=q.amount.sum()) + q = q.order_by([q.nation, ibis.desc(q.o_year)]) + return q + + +@pytest.fixture +def tpc_h10(customer, orders, lineitem, nation): + q = customer + q = q.join(orders, customer.c_custkey == orders.o_custkey) + q = q.join(lineitem, lineitem.l_orderkey == orders.o_orderkey) + q = 
q.join(nation, customer.c_nationkey == nation.n_nationkey) + + q = q.filter( + [ + (q.o_orderdate >= "1993-01-01") & (q.o_orderdate < "1993-04-01"), + q.l_returnflag == "R", + ] + ) + + gq = q.group_by( + [ + q.c_custkey, + q.c_name, + q.c_acctbal, + q.c_phone, + q.n_name, + q.c_address, + q.c_comment, + ] + ) + q = gq.aggregate(revenue=(q.l_extendedprice * (1 - q.l_discount)).sum()) + + q = q.order_by(ibis.desc(q.revenue)) + return q.limit(20) + + +@pytest.fixture +def tpc_h11(partsupp, supplier, nation): + q = partsupp + q = q.join(supplier, partsupp.ps_suppkey == supplier.s_suppkey) + q = q.join(nation, nation.n_nationkey == supplier.s_nationkey) + + q = q.filter([q.n_name == "GERMANY"]) + + innerq = partsupp + innerq = innerq.join(supplier, partsupp.ps_suppkey == supplier.s_suppkey) + innerq = innerq.join(nation, nation.n_nationkey == supplier.s_nationkey) + innerq = innerq.filter([innerq.n_name == "GERMANY"]) + innerq = innerq.aggregate(total=(innerq.ps_supplycost * innerq.ps_availqty).sum()) + + gq = q.group_by([q.ps_partkey]) + q = gq.aggregate(value=(q.ps_supplycost * q.ps_availqty).sum()) + q = q.filter([q.value > innerq.total * 0.0001]) + q = q.order_by(ibis.desc(q.value)) + return q + + +@pytest.fixture +def tpc_h12(orders, lineitem): + q = orders + q = q.join(lineitem, orders.o_orderkey == lineitem.l_orderkey) + + q = q.filter( + [ + q.l_shipmode.isin(["MAIL", "SHIP"]), + q.l_commitdate < q.l_receiptdate, + q.l_shipdate < q.l_commitdate, + q.l_receiptdate >= "1994-01-01", + q.l_receiptdate < "1995-01-01", + ] + ) + + gq = q.group_by([q.l_shipmode]) + q = gq.aggregate( + high_line_count=( + q.o_orderpriority.case() + .when("1-URGENT", 1) + .when("2-HIGH", 1) + .else_(0) + .end() + ).sum(), + low_line_count=( + q.o_orderpriority.case() + .when("1-URGENT", 0) + .when("2-HIGH", 0) + .else_(1) + .end() + ).sum(), + ) + q = q.order_by(q.l_shipmode) + + return q + + +@pytest.fixture +def tpc_h13(customer, orders): + innerq = customer + innerq = 
innerq.left_join( + orders, + (customer.c_custkey == orders.o_custkey) + & ~orders.o_comment.like("%special%requests%"), + ) + innergq = innerq.group_by([innerq.c_custkey]) + innerq = innergq.aggregate(c_count=innerq.o_orderkey.count()) + + gq = innerq.group_by([innerq.c_count]) + q = gq.aggregate(custdist=innerq.count()) + + q = q.order_by([ibis.desc(q.custdist), ibis.desc(q.c_count)]) + return q + + +@pytest.fixture +def tpc_h14(lineitem, part): + q = lineitem + q = q.join(part, lineitem.l_partkey == part.p_partkey) + q = q.filter([q.l_shipdate >= "1995-09-01", q.l_shipdate < "1995-10-01"]) + + revenue = q.l_extendedprice * (1 - q.l_discount) + promo_revenue = q.p_type.like("PROMO%").ifelse(revenue, 0) + + q = q.aggregate(promo_revenue_sum=promo_revenue.sum(), revenue_sum=revenue.sum()) + q = q.mutate(promo_revenue=100 * q.promo_revenue_sum / q.revenue_sum) + q = q.drop("promo_revenue_sum", "revenue_sum") + return q + + +@pytest.fixture +def tpc_h15(lineitem, supplier): + qrev = lineitem + qrev = qrev.filter( + [lineitem.l_shipdate >= "1996-01-01", lineitem.l_shipdate < "1996-04-01"] + ) + + gqrev = qrev.group_by([lineitem.l_suppkey]) + qrev = gqrev.aggregate( + total_revenue=(qrev.l_extendedprice * (1 - qrev.l_discount)).sum() + ) + + q = supplier.join(qrev, supplier.s_suppkey == qrev.l_suppkey) + q = q.filter([q.total_revenue == qrev.total_revenue.max()]) + q = q.order_by([q.s_suppkey]) + q = q.select([q.s_suppkey, q.s_name, q.s_address, q.s_phone, q.total_revenue]) + return q + + +@pytest.fixture +def tpc_h16(partsupp, part, supplier): + q = partsupp.join(part, part.p_partkey == partsupp.ps_partkey) + q = q.filter( + [ + q.p_brand != "Brand#45", + ~q.p_type.like("MEDIUM POLISHED%"), + q.p_size.isin((49, 14, 23, 45, 19, 3, 36, 9)), + ~q.ps_suppkey.isin( + supplier.filter( + [supplier.s_comment.like("%Customer%Complaints%")] + ).s_suppkey + ), + ] + ) + gq = q.group_by([q.p_brand, q.p_type, q.p_size]) + q = gq.aggregate(supplier_cnt=q.ps_suppkey.nunique()) + q = 
q.order_by([ibis.desc(q.supplier_cnt), q.p_brand, q.p_type, q.p_size]) + return q + + +@pytest.fixture +def tpc_h17(lineitem, part): + q = lineitem.join(part, part.p_partkey == lineitem.l_partkey) + + innerq = lineitem + innerq = innerq.filter([innerq.l_partkey == q.p_partkey]) + + q = q.filter( + [ + q.p_brand == "Brand#23", + q.p_container == "MED BOX", + q.l_quantity < (0.2 * innerq.l_quantity.mean()), + ] + ) + q = q.aggregate(avg_yearly=q.l_extendedprice.sum()) + q = q.mutate(avg_yearly=q.avg_yearly / 0.7) + return q + + +@pytest.fixture +def tpc_h18(customer, orders, lineitem): + subgq = lineitem.group_by([lineitem.l_orderkey]) + subq = subgq.aggregate(qty_sum=lineitem.l_quantity.sum()) + subq = subq.filter([subq.qty_sum > 300]) + + q = customer + q = q.join(orders, customer.c_custkey == orders.o_custkey) + q = q.join(lineitem, orders.o_orderkey == lineitem.l_orderkey) + q = q.filter([q.o_orderkey.isin(subq.l_orderkey)]) + + gq = q.group_by( + [q.c_name, q.c_custkey, q.o_orderkey, q.o_orderdate, q.o_totalprice] + ) + q = gq.aggregate(sum_qty=q.l_quantity.sum()) + q = q.order_by([ibis.desc(q.o_totalprice), q.o_orderdate]) + return q.limit(100) + + +@pytest.fixture +def tpc_h19(lineitem, part): + q = lineitem.join(part, part.p_partkey == lineitem.l_partkey) + + q1 = ( + (q.p_brand == "Brand#12") + & (q.p_container.isin(("SM CASE", "SM BOX", "SM PACK", "SM PKG"))) + & (q.l_quantity >= 1) + & (q.l_quantity <= 1 + 10) + & (q.p_size.between(1, 5)) + & (q.l_shipmode.isin(("AIR", "AIR REG"))) + & (q.l_shipinstruct == "DELIVER IN PERSON") + ) + + q2 = ( + (q.p_brand == "Brand#23") + & (q.p_container.isin(("MED BAG", "MED BOX", "MED PKG", "MED PACK"))) + & (q.l_quantity >= 10) + & (q.l_quantity <= 10 + 10) + & (q.p_size.between(1, 10)) + & (q.l_shipmode.isin(("AIR", "AIR REG"))) + & (q.l_shipinstruct == "DELIVER IN PERSON") + ) + + q3 = ( + (q.p_brand == "Brand#34") + & (q.p_container.isin(("LG CASE", "LG BOX", "LG PACK", "LG PKG"))) + & (q.l_quantity >= 20) + & 
(q.l_quantity <= 20 + 10) + & (q.p_size.between(1, 15)) + & (q.l_shipmode.isin(("AIR", "AIR REG"))) + & (q.l_shipinstruct == "DELIVER IN PERSON") + ) + + q = q.filter([q1 | q2 | q3]) + q = q.aggregate(revenue=(q.l_extendedprice * (1 - q.l_discount)).sum()) + return q + + +@pytest.fixture +def tpc_h20(supplier, nation, partsupp, part, lineitem): + q1 = supplier.join(nation, supplier.s_nationkey == nation.n_nationkey) + + q3 = part.filter([part.p_name.like("forest%")]) + q2 = partsupp + + q4 = lineitem.filter( + [ + lineitem.l_partkey == q2.ps_partkey, + lineitem.l_suppkey == q2.ps_suppkey, + lineitem.l_shipdate >= "1994-01-01", + lineitem.l_shipdate < "1995-01-01", + ] + ) + + q2 = q2.filter( + [ + partsupp.ps_partkey.isin(q3.p_partkey), + partsupp.ps_availqty > 0.5 * q4.l_quantity.sum(), + ] + ) + + q1 = q1.filter([q1.n_name == "CANADA", q1.s_suppkey.isin(q2.ps_suppkey)]) + + q1 = q1.select([q1.s_name, q1.s_address]) + + return q1.order_by(q1.s_name) + + +@pytest.fixture +def tpc_h21(supplier, lineitem, orders, nation): + L2 = lineitem.view() + L3 = lineitem.view() + + q = supplier + q = q.join(lineitem, supplier.s_suppkey == lineitem.l_suppkey) + q = q.join(orders, orders.o_orderkey == lineitem.l_orderkey) + q = q.join(nation, supplier.s_nationkey == nation.n_nationkey) + q = q.select( + [ + q.l_orderkey.name("l1_orderkey"), + q.o_orderstatus, + q.l_receiptdate, + q.l_commitdate, + q.l_suppkey.name("l1_suppkey"), + q.s_name, + q.n_name, + ] + ) + q = q.filter( + [ + q.o_orderstatus == "F", + q.l_receiptdate > q.l_commitdate, + q.n_name == "SAUDI ARABIA", + ((L2.l_orderkey == q.l1_orderkey) & (L2.l_suppkey != q.l1_suppkey)).any(), + ~( + ( + (L3.l_orderkey == q.l1_orderkey) + & (L3.l_suppkey != q.l1_suppkey) + & (L3.l_receiptdate > L3.l_commitdate) + ).any() + ), + ] + ) + + gq = q.group_by([q.s_name]) + q = gq.aggregate(numwait=q.count()) + q = q.order_by([ibis.desc(q.numwait), q.s_name]) + return q.limit(100) + + +@pytest.fixture +def tpc_h22(customer, orders): + 
q = customer.filter( + [ + customer.c_acctbal > 0.00, + customer.c_phone.substr(0, 2).isin( + ("13", "31", "23", "29", "30", "18", "17") + ), + ] + ) + q = q.aggregate(avg_bal=customer.c_acctbal.mean()) + + custsale = customer.filter( + [ + customer.c_phone.substr(0, 2).isin( + ("13", "31", "23", "29", "30", "18", "17") + ), + customer.c_acctbal > q.avg_bal, + ~(orders.o_custkey == customer.c_custkey).any(), + ] + ) + custsale = custsale.select( + [customer.c_phone.substr(0, 2).name("cntrycode"), customer.c_acctbal] + ) + + gq = custsale.group_by(custsale.cntrycode) + outerq = gq.aggregate(numcust=custsale.count(), totacctbal=custsale.c_acctbal.sum()) + + return outerq.order_by(outerq.cntrycode) + + +@pytest.fixture +def compiler(): + from letsql.ibis_yaml.compiler import IbisYamlCompiler + + return IbisYamlCompiler() diff --git a/python/letsql/ibis_yaml/tests/test_arithmetic.py b/python/letsql/ibis_yaml/tests/test_arithmetic.py new file mode 100644 index 00000000..4722f680 --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_arithmetic.py @@ -0,0 +1,96 @@ +import ibis + + +def test_add(compiler): + lit1 = ibis.literal(5) + lit2 = ibis.literal(3) + expr = lit1 + lit2 + + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Add" + assert yaml_dict["args"][0]["op"] == "Literal" + assert yaml_dict["args"][0]["value"] == 5 + assert yaml_dict["args"][1]["op"] == "Literal" + assert yaml_dict["args"][1]["value"] == 3 + assert yaml_dict["type"] == {"name": "Int8", "nullable": True} + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_subtract(compiler): + lit1 = ibis.literal(5) + lit2 = ibis.literal(3) + expr = lit1 - lit2 + + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Subtract" + assert yaml_dict["args"][0]["op"] == "Literal" + assert yaml_dict["args"][0]["value"] == 5 + assert yaml_dict["args"][1]["op"] == "Literal" + assert yaml_dict["args"][1]["value"] == 3 + 
assert yaml_dict["type"] == {"name": "Int8", "nullable": True} + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_multiply(compiler): + lit1 = ibis.literal(5) + lit2 = ibis.literal(3) + expr = lit1 * lit2 + + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Multiply" + assert yaml_dict["args"][0]["op"] == "Literal" + assert yaml_dict["args"][0]["value"] == 5 + assert yaml_dict["args"][1]["op"] == "Literal" + assert yaml_dict["args"][1]["value"] == 3 + assert yaml_dict["type"] == {"name": "Int8", "nullable": True} + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_divide(compiler): + lit1 = ibis.literal(6.0) + lit2 = ibis.literal(2.0) + expr = lit1 / lit2 + + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Divide" + assert yaml_dict["args"][0]["op"] == "Literal" + assert yaml_dict["args"][0]["value"] == 6.0 + assert yaml_dict["args"][1]["op"] == "Literal" + assert yaml_dict["args"][1]["value"] == 2.0 + assert yaml_dict["type"] == {"name": "Float64", "nullable": True} + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_mixed_arithmetic(compiler): + i = ibis.literal(5) + f = ibis.literal(2.5) + expr = i * f + + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Multiply" + assert yaml_dict["type"] == {"name": "Float64", "nullable": True} + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_complex_arithmetic(compiler): + a = ibis.literal(10) + b = ibis.literal(5) + c = ibis.literal(2.0) + expr = (a + b) * c + + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Multiply" + assert yaml_dict["args"][0]["op"] == "Add" + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) diff --git 
a/python/letsql/ibis_yaml/tests/test_basic.py b/python/letsql/ibis_yaml/tests/test_basic.py new file mode 100644 index 00000000..a8b40e59 --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_basic.py @@ -0,0 +1,156 @@ +import datetime +import decimal + +import ibis + + +def test_unbound_table(t, compiler): + yaml_dict = compiler.compile_to_yaml(t) + assert yaml_dict["op"] == "UnboundTable" + assert yaml_dict["name"] == "test_table" + assert yaml_dict["schema"]["a"] == {"name": "Int64", "nullable": True} + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.schema() == t.schema() + assert roundtrip_expr.op().name == t.op().name + + +def test_field(t, compiler): + expr = t.a + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Field" + assert yaml_dict["name"] == "a" + assert yaml_dict["type"] == {"name": "Int64", "nullable": True} + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + assert roundtrip_expr.get_name() == expr.get_name() + + +def test_literal(compiler): + lit = ibis.literal(42) + yaml_dict = compiler.compile_to_yaml(lit) + assert yaml_dict["op"] == "Literal" + assert yaml_dict["value"] == 42 + assert yaml_dict["type"] == {"name": "Int8", "nullable": True} + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(lit) + + +def test_binary_op(t, compiler): + expr = t.a + 1 + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Add" + assert yaml_dict["args"][0]["op"] == "Field" + assert yaml_dict["args"][1]["op"] == "Literal" + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_primitive_types(compiler): + primitives = [ + (ibis.literal(True), "Boolean"), + (ibis.literal(1), "Int8"), + (ibis.literal(1000), "Int16"), + (ibis.literal(1.0), "Float64"), + (ibis.literal("hello"), "String"), + (ibis.literal(None), "Null"), + ] + for lit, expected_type in 
primitives: + yaml_dict = compiler.compile_to_yaml(lit) + assert yaml_dict["op"] == "Literal" + assert yaml_dict["type"]["name"] == expected_type + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(lit) + assert roundtrip_expr.type().name == lit.type().name + + +def test_temporal_types(compiler): + now = datetime.datetime.now() + today = datetime.date.today() + time = datetime.time(12, 0) + temporals = [ + (ibis.literal(now), "Timestamp"), + (ibis.literal(today), "Date"), + (ibis.literal(time), "Time"), + ] + for lit, expected_type in temporals: + yaml_dict = compiler.compile_to_yaml(lit) + assert yaml_dict["op"] == "Literal" + assert yaml_dict["type"]["name"] == expected_type + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(lit) + assert roundtrip_expr.type().name == lit.type().name + + +def test_decimal_type(compiler): + dec = decimal.Decimal("123.45") + lit = ibis.literal(dec) + yaml_dict = compiler.compile_to_yaml(lit) + assert yaml_dict["op"] == "Literal" + assert yaml_dict["type"]["name"] == "Decimal" + assert yaml_dict["type"]["nullable"] + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(lit) + assert roundtrip_expr.type().name == lit.type().name + + +def test_array_type(compiler): + lit = ibis.literal([1, 2, 3]) + yaml_dict = compiler.compile_to_yaml(lit) + assert yaml_dict["op"] == "Literal" + assert yaml_dict["type"]["name"] == "Array" + assert yaml_dict["type"]["value_type"]["name"] == "Int8" + assert yaml_dict["value"] == (1, 2, 3) + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(lit) + assert roundtrip_expr.type().value_type == lit.type().value_type + + +def test_map_type(compiler): + lit = ibis.literal({"a": 1, "b": 2}) + yaml_dict = compiler.compile_to_yaml(lit) + assert yaml_dict["op"] == "Literal" + assert yaml_dict["type"]["name"] == "Map" + assert yaml_dict["type"]["key_type"]["name"] == 
"String" + assert yaml_dict["type"]["value_type"]["name"] == "Int8" + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(lit) + assert roundtrip_expr.type().key_type == lit.type().key_type + assert roundtrip_expr.type().value_type == lit.type().value_type + + +def test_complex_expression_roundtrip(t, compiler): + expr = (t.a + 1).abs() * 2 + yaml_dict = compiler.compile_to_yaml(expr) + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_window_function_roundtrip(t, compiler): + expr = t.a.sum().over(ibis.window(group_by=t.a)) + yaml_dict = compiler.compile_to_yaml(expr) + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_join_roundtrip(t, compiler): + t2 = ibis.table({"b": "int64"}, name="test_table_2") + expr = t.join(t2, t.a == t2.b) + yaml_dict = compiler.compile_to_yaml(expr) + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.schema() == expr.schema() + + +def test_aggregation_roundtrip(t, compiler): + expr = t.group_by(t.a).aggregate(count=t.a.count()) + yaml_dict = compiler.compile_to_yaml(expr) + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.schema() == expr.schema() diff --git a/python/letsql/ibis_yaml/tests/test_join_chain.py b/python/letsql/ibis_yaml/tests/test_join_chain.py new file mode 100644 index 00000000..3fffe466 --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_join_chain.py @@ -0,0 +1,101 @@ +import ibis +import pytest + +from letsql.ibis_yaml.compiler import IbisYamlCompiler + + +@pytest.fixture +def orders(): + return ibis.table( + { + "o_orderkey": "int32", + "o_orderpriority": "string", + "o_custkey": "int32", + "o_orderdate": "date", + }, + name="orders", + ) + + +@pytest.fixture +def supplier(): + return ibis.table( + { + "s_suppkey": "int32", + "s_nationkey": "int32", + }, + name="supplier", + ) + + +@pytest.fixture +def 
lineitem(): + return ibis.table( + { + "l_orderkey": "int64", + "l_suppkey": "int32", + "l_shipdate": "date", + "l_extendedprice": "decimal(15,2)", + "l_discount": "decimal(15,2)", + }, + name="lineitem", + ) + + +@pytest.fixture +def customer(): + return ibis.table( + { + "c_custkey": "int32", + "c_nationkey": "int32", + }, + name="customer", + ) + + +@pytest.fixture +def nation(): + return ibis.table( + { + "n_nationkey": "int32", + "n_name": "string", + }, + name="nation", + ) + + +# Minimal test mimicking h07 join chain with self-reference in projection +def test_minimal_joinchain_self_reference( + compiler, orders, supplier, lineitem, customer, nation +): + q = supplier.join(lineitem, supplier.s_suppkey == lineitem.l_suppkey) + q = q.join(orders, orders.o_orderkey == lineitem.l_orderkey) + q = q.join(customer, customer.c_custkey == orders.o_custkey) + n1 = nation + n2 = nation.view() + q = q.join(n1, supplier.s_nationkey == n1.n_nationkey) + q = q.join(n2, customer.c_nationkey == n2.n_nationkey) + q = q.projection( + { + "supp_nation": n1.n_name, + "cust_nation": n2.n_name, + "l_shipdate": lineitem.l_shipdate, + "l_extendedprice": lineitem.l_extendedprice, + "l_discount": lineitem.l_discount, + } + ) + q = q.filter( + ( + ((q.cust_nation == "FRANCE") & (q.supp_nation == "GERMANY")) + | ((q.cust_nation == "GERMANY") & (q.supp_nation == "FRANCE")) + ) + ) + + compiler = IbisYamlCompiler() + yaml_dict = compiler.compile_to_yaml(q) + q_roundtrip = compiler.compile_from_yaml(yaml_dict) + + try: + _ = q_roundtrip["cust_nation"] + except Exception as e: + pytest.fail(f"Accessing 'cust_nation' on the roundtrip query failed: {e}") diff --git a/python/letsql/ibis_yaml/tests/test_operations_boolean.py b/python/letsql/ibis_yaml/tests/test_operations_boolean.py new file mode 100644 index 00000000..2e0bd936 --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_operations_boolean.py @@ -0,0 +1,104 @@ +import ibis + + +def test_equals(compiler): + a = ibis.literal(5) + b = 
ibis.literal(5) + expr = a == b + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Equals" + assert yaml_dict["args"][0]["value"] == 5 + assert yaml_dict["args"][1]["value"] == 5 + assert yaml_dict["type"] == {"name": "Boolean", "nullable": True} + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_not_equals(compiler): + a = ibis.literal(5) + b = ibis.literal(3) + expr = a != b + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "NotEquals" + assert yaml_dict["args"][0]["value"] == 5 + assert yaml_dict["args"][1]["value"] == 3 + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_greater_than(compiler): + a = ibis.literal(5) + b = ibis.literal(3) + expr = a > b + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Greater" + assert yaml_dict["args"][0]["value"] == 5 + assert yaml_dict["args"][1]["value"] == 3 + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_less_than(compiler): + a = ibis.literal(3) + b = ibis.literal(5) + expr = a < b + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Less" + assert yaml_dict["args"][0]["value"] == 3 + assert yaml_dict["args"][1]["value"] == 5 + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_and_or(compiler): + a = ibis.literal(5) + b = ibis.literal(3) + c = ibis.literal(10) + + expr_and = (a > b) & (a < c) + yaml_dict = compiler.compile_to_yaml(expr_and) + assert yaml_dict["op"] == "And" + assert yaml_dict["args"][0]["op"] == "Greater" + assert yaml_dict["args"][1]["op"] == "Less" + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr_and) + + expr_or = (a > b) | (a < c) + yaml_dict = compiler.compile_to_yaml(expr_or) + assert yaml_dict["op"] == "Or" + assert 
yaml_dict["args"][0]["op"] == "Greater" + assert yaml_dict["args"][1]["op"] == "Less" + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr_or) + + +def test_not(compiler): + a = ibis.literal(True) + expr = ~a + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Not" + assert yaml_dict["args"][0]["value"] + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_is_null(compiler): + a = ibis.literal(None) + expr = a.isnull() + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "IsNull" + assert yaml_dict["args"][0]["value"] is None + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_between(compiler): + a = ibis.literal(5) + expr = a.between(3, 7) + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Between" + assert yaml_dict["args"][0]["value"] == 5 + assert yaml_dict["args"][1]["value"] == 3 + assert yaml_dict["args"][2]["value"] == 7 + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_operations_cast.py b/python/letsql/ibis_yaml/tests/test_operations_cast.py new file mode 100644 index 00000000..4039ec26 --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_operations_cast.py @@ -0,0 +1,53 @@ +import ibis + + +def test_explicit_cast(compiler): + expr = ibis.literal(42).cast("float64") + yaml_dict = compiler.compile_to_yaml(expr) + + assert yaml_dict["op"] == "Cast" + assert yaml_dict["args"][0]["op"] == "Literal" + assert yaml_dict["args"][0]["value"] == 42 + assert yaml_dict["type"]["name"] == "Float64" + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_implicit_cast(compiler): + i = ibis.literal(1) + f = ibis.literal(2.5) + expr = i + f + yaml_dict = compiler.compile_to_yaml(expr) + + assert yaml_dict["op"] == 
"Add" + assert yaml_dict["args"][0]["type"]["name"] == "Int8" + assert yaml_dict["args"][1]["type"]["name"] == "Float64" + assert yaml_dict["type"]["name"] == "Float64" + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_string_cast(compiler): + expr = ibis.literal("42").cast("int64") + yaml_dict = compiler.compile_to_yaml(expr) + + assert yaml_dict["op"] == "Cast" + assert yaml_dict["args"][0]["value"] == "42" + assert yaml_dict["type"]["name"] == "Int64" + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_timestamp_cast(compiler): + expr = ibis.literal("2024-01-01").cast("timestamp") + yaml_dict = compiler.compile_to_yaml(expr) + + assert yaml_dict["op"] == "Cast" + assert yaml_dict["args"][0]["value"] == "2024-01-01" + assert yaml_dict["type"]["name"] == "Timestamp" + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_operations_datetime.py b/python/letsql/ibis_yaml/tests/test_operations_datetime.py new file mode 100644 index 00000000..34a3c063 --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_operations_datetime.py @@ -0,0 +1,105 @@ +"""Test datetime operation translations.""" + +from datetime import datetime + +import ibis +import ibis.expr.datatypes as dt +import ibis.expr.operations.temporal as tm + + +def test_date_extract(compiler): + dt_expr = ibis.literal(datetime(2024, 3, 14, 15, 9, 26)) + + year = dt_expr.year() + year_yaml = compiler.compile_to_yaml(year) + assert year_yaml["op"] == "ExtractYear" + assert year_yaml["args"][0]["value"] == "2024-03-14T15:09:26" + assert year_yaml["type"]["name"] == "Int32" + roundtrip_year = compiler.compile_from_yaml(year_yaml) + assert roundtrip_year.equals(year) + + month = dt_expr.month() + month_yaml = compiler.compile_to_yaml(month) + assert month_yaml["op"] == "ExtractMonth" + roundtrip_month = 
compiler.compile_from_yaml(month_yaml) + assert roundtrip_month.equals(month) + + day = dt_expr.day() + day_yaml = compiler.compile_to_yaml(day) + assert day_yaml["op"] == "ExtractDay" + roundtrip_day = compiler.compile_from_yaml(day_yaml) + assert roundtrip_day.equals(day) + + +def test_time_extract(compiler): + dt_expr = ibis.literal(datetime(2024, 3, 14, 15, 9, 26)) + + hour = dt_expr.hour() + hour_yaml = compiler.compile_to_yaml(hour) + assert hour_yaml["op"] == "ExtractHour" + assert hour_yaml["args"][0]["value"] == "2024-03-14T15:09:26" + assert hour_yaml["type"]["name"] == "Int32" + roundtrip_hour = compiler.compile_from_yaml(hour_yaml) + assert roundtrip_hour.equals(hour) + + minute = dt_expr.minute() + minute_yaml = compiler.compile_to_yaml(minute) + assert minute_yaml["op"] == "ExtractMinute" + roundtrip_minute = compiler.compile_from_yaml(minute_yaml) + assert roundtrip_minute.equals(minute) + + second = dt_expr.second() + second_yaml = compiler.compile_to_yaml(second) + assert second_yaml["op"] == "ExtractSecond" + roundtrip_second = compiler.compile_from_yaml(second_yaml) + assert roundtrip_second.equals(second) + + +def test_timestamp_arithmetic(compiler): + ts = ibis.literal(datetime(2024, 3, 14, 15, 9, 26)) + delta = ibis.interval(days=1) + + plus_day = ts + delta + yaml_dict = compiler.compile_to_yaml(plus_day) + assert yaml_dict["op"] == "TimestampAdd" + assert yaml_dict["type"]["name"] == "Timestamp" + assert yaml_dict["args"][1]["type"]["name"] == "Interval" + roundtrip_plus = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_plus.equals(plus_day) + + minus_day = ts - delta + yaml_dict = compiler.compile_to_yaml(minus_day) + assert yaml_dict["op"] == "TimestampSub" + assert yaml_dict["type"]["name"] == "Timestamp" + assert yaml_dict["args"][1]["type"]["name"] == "Interval" + roundtrip_minus = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_minus.equals(minus_day) + + +def test_timestamp_diff(compiler): + ts1 = 
ibis.literal(datetime(2024, 3, 14)) + ts2 = ibis.literal(datetime(2024, 3, 15)) + diff = ts2 - ts1 + yaml_dict = compiler.compile_to_yaml(diff) + assert yaml_dict["op"] == "TimestampDiff" + assert yaml_dict["type"]["name"] == "Interval" + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(diff) + + +def test_temporal_unit_yaml(compiler): + interval_date = ibis.literal(5, type=dt.Interval(unit=tm.DateUnit("D"))) + yaml_date = compiler.compile_to_yaml(interval_date) + assert yaml_date["type"]["name"] == "Interval" + assert yaml_date["type"]["unit"]["name"] == "DateUnit" + assert yaml_date["type"]["unit"]["value"] == "D" + roundtrip_date = compiler.compile_from_yaml(yaml_date) + assert roundtrip_date.equals(interval_date) + + interval_time = ibis.literal(10, type=dt.Interval(unit=tm.TimeUnit("h"))) + yaml_time = compiler.compile_to_yaml(interval_time) + assert yaml_time["type"]["name"] == "Interval" + assert yaml_time["type"]["unit"]["name"] == "TimeUnit" + assert yaml_time["type"]["unit"]["value"] == "h" + roundtrip_time = compiler.compile_from_yaml(yaml_time) + assert roundtrip_time.equals(interval_time) diff --git a/python/letsql/ibis_yaml/tests/test_relations.py b/python/letsql/ibis_yaml/tests/test_relations.py new file mode 100644 index 00000000..c14a5799 --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_relations.py @@ -0,0 +1,86 @@ +import ibis + + +def test_filter(compiler, t): + expr = t.filter(t.a > 0) + yaml_dict = compiler.compile_to_yaml(expr) + + # Original assertions + assert yaml_dict["op"] == "Filter" + assert yaml_dict["predicates"][0]["op"] == "Greater" + assert yaml_dict["parent"]["op"] == "UnboundTable" + + # Roundtrip test: compile from YAML and verify equality + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_projection(compiler, t): + expr = t.select(["a", "b"]) + yaml_dict = compiler.compile_to_yaml(expr) + + # Original assertions + assert 
yaml_dict["op"] == "Project" + assert yaml_dict["parent"]["op"] == "UnboundTable" + assert set(yaml_dict["values"]) == {"a", "b"} + + # Roundtrip test + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_aggregation(compiler, t): + expr = t.group_by("a").aggregate(avg_c=t.c.mean()) + yaml_dict = compiler.compile_to_yaml(expr) + + # Original assertions + assert yaml_dict["op"] == "Aggregate" + assert yaml_dict["by"][0]["name"] == "a" + assert yaml_dict["metrics"]["avg_c"]["op"] == "Mean" + + # Roundtrip test + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_join(compiler): + t1 = ibis.table(dict(a="int", b="string"), name="t1") + t2 = ibis.table(dict(b="string", c="float"), name="t2") + expr = t1.join(t2, t1.b == t2.b) + yaml_dict = compiler.compile_to_yaml(expr) + + # Original assertions + assert yaml_dict["op"] == "JoinChain" + # The first join link's predicates + assert yaml_dict["rest"][0]["predicates"][0]["op"] == "Equals" + assert yaml_dict["rest"][0]["how"] == "inner" + + # Roundtrip test + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_order_by(compiler, t): + expr = t.order_by(["a", "b"]) + yaml_dict = compiler.compile_to_yaml(expr) + + # Original assertions + assert yaml_dict["op"] == "Sort" + assert len(yaml_dict["keys"]) == 2 + + # Roundtrip test + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_limit(compiler, t): + expr = t.limit(10) + yaml_dict = compiler.compile_to_yaml(expr) + + # Original assertions + assert yaml_dict["op"] == "Limit" + assert yaml_dict["n"] == 10 + + # Roundtrip test + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_selection.py b/python/letsql/ibis_yaml/tests/test_selection.py new file mode 100644 index 
00000000..486bb859 --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_selection.py @@ -0,0 +1,13 @@ +import ibis + + +def test_selection_on_view(compiler): + T = ibis.table({"id": "int32", "name": "string"}, name="T") + T_view = T.view() + q = T.join(T_view, T.id == T_view.id) + q = q.select({"alias_name": T_view.name}) + q = q.filter(q.alias_name == "X") + + yaml_dict = compiler.compile_to_yaml(q) + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(q) diff --git a/python/letsql/ibis_yaml/tests/test_string_ops.py b/python/letsql/ibis_yaml/tests/test_string_ops.py new file mode 100644 index 00000000..74adbfac --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_string_ops.py @@ -0,0 +1,48 @@ +import ibis + + +def test_string_concat(compiler): + s1 = ibis.literal("hello") + s2 = ibis.literal("world") + expr = s1 + s2 + yaml_dict = compiler.compile_to_yaml(expr) + + assert yaml_dict["op"] == "StringConcat" + assert yaml_dict["args"][0]["value"] == "hello" + assert yaml_dict["args"][1]["value"] == "world" + assert yaml_dict["type"] == {"name": "String", "nullable": True} + + +def test_string_upper_lower(compiler): + s = ibis.literal("Hello") + upper_expr = s.upper() + lower_expr = s.lower() + + upper_yaml = compiler.compile_to_yaml(upper_expr) + assert upper_yaml["op"] == "Uppercase" + assert upper_yaml["args"][0]["value"] == "Hello" + + lower_yaml = compiler.compile_to_yaml(lower_expr) + assert lower_yaml["op"] == "Lowercase" + assert lower_yaml["args"][0]["value"] == "Hello" + + +def test_string_length(compiler): + s = ibis.literal("hello") + expr = s.length() + yaml_dict = compiler.compile_to_yaml(expr) + + assert yaml_dict["op"] == "StringLength" + assert yaml_dict["args"][0]["value"] == "hello" + assert yaml_dict["type"] == {"name": "Int32", "nullable": True} + + +def test_string_substring(compiler): + s = ibis.literal("hello world") + expr = s.substr(0, 5) + yaml_dict = compiler.compile_to_yaml(expr) + + assert 
yaml_dict["op"] == "Substring" + assert yaml_dict["args"][0]["value"] == "hello world" + assert yaml_dict["args"][1]["value"] == 0 + assert yaml_dict["args"][2]["value"] == 5 diff --git a/python/letsql/ibis_yaml/tests/test_subquery.py b/python/letsql/ibis_yaml/tests/test_subquery.py new file mode 100644 index 00000000..961e68a8 --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_subquery.py @@ -0,0 +1,42 @@ +import ibis +import ibis.expr.operations as ops + + +def test_scalar_subquery(compiler, t): + expr = ops.ScalarSubquery(t.c.mean().as_table()).to_expr() + yaml_dict = compiler.compile_to_yaml(expr) + + assert yaml_dict["op"] == "ScalarSubquery" + assert yaml_dict["args"][0]["op"] == "Aggregate" + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_exists_subquery(compiler): + t1 = ibis.table(dict(a="int", b="string"), name="t1") + t2 = ibis.table(dict(a="int", c="float"), name="t2") + + filtered = t2.filter(t2.a == t1.a) + expr = ops.ExistsSubquery(filtered).to_expr() + yaml_dict = compiler.compile_to_yaml(expr) + + assert yaml_dict["op"] == "ExistsSubquery" + assert yaml_dict["rel"]["op"] == "Filter" + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) + + +def test_in_subquery(compiler): + t1 = ibis.table(dict(a="int", b="string"), name="t1") + t2 = ibis.table(dict(a="int", c="float"), name="t2") + + expr = ops.InSubquery(t1.select("a"), t2.a).to_expr() + yaml_dict = compiler.compile_to_yaml(expr) + + assert yaml_dict["op"] == "InSubquery" + assert yaml_dict["type"]["name"] == "Boolean" + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_tpch.py b/python/letsql/ibis_yaml/tests/test_tpch.py new file mode 100644 index 00000000..ecab4a92 --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_tpch.py @@ -0,0 +1,50 @@ +import pytest + +import letsql.ibis_yaml +import 
letsql.ibis_yaml.utils + + +TPC_H = [ + "tpc_h01", + "tpc_h02", + "tpc_h03", + "tpc_h04", + "tpc_h05", + "tpc_h06", + "tpc_h07", + "tpc_h08", + "tpc_h09", + "tpc_h10", + "tpc_h11", + "tpc_h12", + "tpc_h13", + "tpc_h14", + "tpc_h15", + "tpc_h16", + "tpc_h17", + "tpc_h18", + "tpc_h19", + "tpc_h20", + "tpc_h21", + "tpc_h22", +] + + +@pytest.mark.parametrize("fixture_name", TPC_H) +def test_yaml_roundtrip(fixture_name, compiler, request): + compiler = letsql.ibis_yaml.compiler.IbisYamlCompiler() + query = request.getfixturevalue(fixture_name) + + yaml_dict = compiler.compile_to_yaml(query) + print("Original Query:") + print(query) + + roundtrip_query = compiler.compile_from_yaml(yaml_dict) + print("Roundtrip Query:") + print(roundtrip_query) + + letsql.ibis_yaml.utils.diff_ibis_exprs(query, roundtrip_query) + + assert roundtrip_query.equals(query), ( + f"Roundtrip expression for {fixture_name} does not match the original." + ) diff --git a/python/letsql/ibis_yaml/tests/test_udf.py b/python/letsql/ibis_yaml/tests/test_udf.py new file mode 100644 index 00000000..52696d3e --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_udf.py @@ -0,0 +1,59 @@ +import ibis +import pytest + +import letsql.ibis_yaml +import letsql.ibis_yaml.utils + + +def test_built_in_udf_properties(compiler): + t = ibis.table({"a": "int64"}, name="t") + + @ibis.udf.scalar.builtin + def add_one(x: int) -> int: + return x + 1 + + expr = t.mutate(new=add_one(t.a)) + yaml_dict = compiler.compile_to_yaml(expr) + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + + original_mutation = expr.op() + roundtrip_mutation = roundtrip_expr.op() + + original_udf = original_mutation.values["new"] + roundtrip_udf = roundtrip_mutation.values["new"] + + assert original_udf.__func_name__ == roundtrip_udf.__func_name__ + assert original_udf.__input_type__ == roundtrip_udf.__input_type__ + assert original_udf.dtype == roundtrip_udf.dtype + assert len(original_udf.args) == len(roundtrip_udf.args) + + for orig_arg, 
rt_arg in zip(original_udf.args, roundtrip_udf.args): + assert orig_arg.dtype == rt_arg.dtype + + +@pytest.mark.xfail( + reason="UDFs do not have the same memory address when pickled/unpickled" +) +def test_built_in_udf(compiler): + # (Pdb) diffs[3][2].args[0] == diffs[3][1].args[0] + # False + # (Pdb) diffs[3][2].args[0] + # + # (Pdb) diffs[3][2].args[0].args + # (, {'a': , 'new': }) + # (Pdb) diffs[3][1].args[0].args + # (, {'a': , 'new': }) + t = ibis.table({"a": "int64"}, name="t") + + @ibis.udf.scalar.builtin + def add_one(x: int) -> int: + pass + + expr = t.mutate(new=add_one(t.a)) + yaml_dict = compiler.compile_to_yaml(expr) + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + print(f"Original {expr}") + print(f"Roundtrip {roundtrip_expr}") + letsql.ibis_yaml.utils.diff_ibis_exprs(expr, roundtrip_expr) + + assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_window_functions.py b/python/letsql/ibis_yaml/tests/test_window_functions.py new file mode 100644 index 00000000..805aa52d --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_window_functions.py @@ -0,0 +1,57 @@ +import ibis + + +def test_window_function_roundtrip(compiler, t): + expr = t.select( + [ + t.c.mean() + .over(ibis.window(preceding=5, following=0, group_by=t.a)) + .name("mean_c") + ] + ) + + yaml_dict = compiler.compile_to_yaml(expr) + + reconstructed_expr = compiler.compile_from_yaml(yaml_dict) + + assert expr.equals(reconstructed_expr) + + +def test_aggregation_window(compiler, t): + cases = [ + (None, None), + (5, 0), + (0, 5), + (5, 5), + ] + + for preceding, following in cases: + expr = t.select( + [ + t.c.mean() + .over( + ibis.window(preceding=preceding, following=following, group_by=t.a) + ) + .name("mean_c") + ] + ) + + yaml_dict = compiler.compile_to_yaml(expr) + assert yaml_dict["op"] == "Project" + window_func = yaml_dict["values"]["mean_c"] + assert window_func["op"] == "WindowFunction" + assert window_func["args"][0]["op"] == "Mean" + + if 
preceding is None: + assert "start" not in window_func + else: + assert window_func["start"] == preceding + + if following is None: + assert "end" not in window_func + else: + assert window_func["end"] == following + + print(yaml_dict) + + assert window_func["group_by"][0]["name"] == "a" diff --git a/python/letsql/ibis_yaml/translate.py b/python/letsql/ibis_yaml/translate.py new file mode 100644 index 00000000..89ebe7ae --- /dev/null +++ b/python/letsql/ibis_yaml/translate.py @@ -0,0 +1,1157 @@ +from __future__ import annotations + +import datetime +import decimal +import functools +from typing import Any + +import ibis +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +import ibis.expr.operations.temporal as tm +import ibis.expr.rules as rlz +import ibis.expr.types as ir +from ibis.common.annotations import Argument +from ibis.common.exceptions import IbisTypeError + +from letsql.ibis_yaml.utils import ( + deserialize_udf_function, + freeze, + serialize_udf_function, +) + + +FROM_YAML_HANDLERS: dict[str, Any] = {} + + +def register_from_yaml_handler(*op_names: str): + def decorator(func): + for name in op_names: + FROM_YAML_HANDLERS[name] = func + return func + + return decorator + + +@functools.cache +@functools.singledispatch +def translate_from_yaml(yaml_dict: dict, compiler: Any) -> Any: + op_type = yaml_dict["op"] + if op_type not in FROM_YAML_HANDLERS: + raise NotImplementedError(f"No handler for operation {op_type}") + return FROM_YAML_HANDLERS[op_type](yaml_dict, compiler) + + +@functools.cache +@functools.singledispatch +def translate_to_yaml(op: Any, compiler: Any) -> dict: + raise NotImplementedError(f"No translation rule for {type(op)}") + + +@functools.singledispatch +def _translate_type(dtype: dt.DataType) -> dict: + return freeze({"name": type(dtype).__name__, "nullable": dtype.nullable}) + + +@_translate_type.register(dt.Timestamp) +def _translate_timestamp_type(dtype: dt.Timestamp) -> dict: + base = {"name": "Timestamp", 
"nullable": dtype.nullable} + if dtype.timezone is not None: + base["timezone"] = dtype.timezone + return freeze(base) + + +@_translate_type.register(dt.Decimal) +def _translate_decimal_type(dtype: dt.Decimal) -> dict: + base = {"name": "Decimal", "nullable": dtype.nullable} + if dtype.precision is not None: + base["precision"] = dtype.precision + if dtype.scale is not None: + base["scale"] = dtype.scale + return freeze(base) + + +@_translate_type.register(dt.Array) +def _translate_array_type(dtype: dt.Array) -> dict: + return freeze( + { + "name": "Array", + "value_type": _translate_type(dtype.value_type), + "nullable": dtype.nullable, + } + ) + + +@_translate_type.register(dt.Map) +def _translate_map_type(dtype: dt.Map) -> dict: + return freeze( + { + "name": "Map", + "key_type": _translate_type(dtype.key_type), + "value_type": _translate_type(dtype.value_type), + "nullable": dtype.nullable, + } + ) + + +@_translate_type.register(dt.Interval) +def _tranlate_type_interval(dtype: dt.Interval) -> dict: + return freeze( + { + "name": "Interval", + "unit": _translate_temporal_unit(dtype.unit), + "nullable": dtype.nullable, + } + ) + + +@_translate_type.register(dt.Struct) +def _translate_struct_type(dtype: dt.Struct) -> dict: + return freeze( + { + "name": "Struct", + "fields": { + name: _translate_type(field_type) + for name, field_type in zip(dtype.names, dtype.types) + }, + "nullable": dtype.nullable, + } + ) + + +def _translate_temporal_unit(unit: tm.IntervalUnit) -> dict: + if unit.is_date(): + unit_name = "DateUnit" + elif unit.is_time(): + unit_name = "TimeUnit" + else: + unit_name = "IntervalUnit" + return freeze({"name": unit_name, "value": unit.value}) + + +def _translate_literal_value(value: Any, dtype: dt.DataType) -> Any: + if value is None: + return None + elif isinstance(value, (bool, int, float, str)): + return value + elif isinstance(value, decimal.Decimal): + return str(value) + elif isinstance(value, (datetime.datetime, datetime.date, 
datetime.time)): + return value.isoformat() + elif isinstance(value, list): + return [_translate_literal_value(v, dtype.value_type) for v in value] + elif isinstance(value, dict): + return { + _translate_literal_value(k, dtype.key_type): _translate_literal_value( + v, dtype.value_type + ) + for k, v in value.items() + } + else: + return value + + +@translate_to_yaml.register(ops.WindowFunction) +def _window_function_to_yaml(op: ops.WindowFunction, compiler: Any) -> dict: + result = { + "op": "WindowFunction", + "args": [translate_to_yaml(op.func, compiler)], + "type": _translate_type(op.dtype), + } + + if op.group_by: + result["group_by"] = [translate_to_yaml(expr, compiler) for expr in op.group_by] + + if op.order_by: + result["order_by"] = [translate_to_yaml(expr, compiler) for expr in op.order_by] + + if op.start is not None: + result["start"] = ( + translate_to_yaml(op.start.value, compiler)["value"] + if isinstance(op.start, ops.WindowBoundary) + else op.start + ) + + if op.end is not None: + result["end"] = ( + translate_to_yaml(op.end.value, compiler)["value"] + if isinstance(op.end, ops.WindowBoundary) + else op.end + ) + + return freeze(result) + + +@register_from_yaml_handler("WindowFunction") +def _window_function_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + func = translate_from_yaml(yaml_dict["args"][0], compiler) + group_by = [translate_from_yaml(g, compiler) for g in yaml_dict.get("group_by", [])] + order_by = [translate_from_yaml(o, compiler) for o in yaml_dict.get("order_by", [])] + start = ibis.literal(yaml_dict["start"]) if "start" in yaml_dict else None + end = ibis.literal(yaml_dict["end"]) if "end" in yaml_dict else None + window = ibis.window( + group_by=group_by, order_by=order_by, preceding=start, following=end + ) + return func.over(window) + + +@translate_to_yaml.register(ops.WindowBoundary) +def _window_boundary_to_yaml(op: ops.WindowBoundary, compiler: Any) -> dict: + return freeze( + { + "op": "WindowBoundary", + "value": 
translate_to_yaml(op.value, compiler), + "preceding": op.preceding, + "type": _translate_type(op.dtype), + } + ) + + +@register_from_yaml_handler("WindowBoundary") +def _window_boundary_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + value = translate_from_yaml(yaml_dict["value"], compiler) + return ops.WindowBoundary(value, preceding=yaml_dict["preceding"]) + + +@translate_to_yaml.register(ops.Node) +def _base_op_to_yaml(op: ops.Node, compiler: Any) -> dict: + return freeze( + { + "op": type(op).__name__, + "args": [ + translate_to_yaml(arg, compiler) + for arg in op.args + if isinstance(arg, (ops.Value, ops.Node)) + ], + } + ) + + +@translate_to_yaml.register(ops.UnboundTable) +def _unbound_table_to_yaml(op: ops.UnboundTable, compiler: Any) -> dict: + return freeze( + { + "op": "UnboundTable", + "name": op.name, + "schema": { + name: _translate_type(dtype) for name, dtype in op.schema.items() + }, + } + ) + + +@register_from_yaml_handler("UnboundTable") +def _unbound_table_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + table_name = yaml_dict["name"] + schema = [ + (name, _type_from_yaml(dtype)) for name, dtype in yaml_dict["schema"].items() + ] + return ibis.table(schema, name=table_name) + + +@translate_to_yaml.register(ops.InMemoryTable) +def _memtable_to_yaml(op: ops.InMemoryTable, compiler: Any) -> dict: + if not hasattr(compiler, "table_data"): + compiler.table_data = {} + compiler.table_data[id(op)] = op.data + + return _unbound_table_to_yaml( + ops.UnboundTable(name=op.name, schema=op.schema), compiler + ) + + +@register_from_yaml_handler("InMemoryTable") +def _memtable_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + return _unbound_table_from_yaml(yaml_dict, compiler) + + +@translate_to_yaml.register(ops.Literal) +def _literal_to_yaml(op: ops.Literal, compiler: Any) -> dict: + value = _translate_literal_value(op.value, op.dtype) + return freeze({"op": "Literal", "value": value, "type": _translate_type(op.dtype)}) + + 
+@register_from_yaml_handler("Literal") +def _literal_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + value = yaml_dict["value"] + dtype = _type_from_yaml(yaml_dict["type"]) + return ibis.literal(value, type=dtype) + + +@translate_to_yaml.register(ops.ValueOp) +def _value_op_to_yaml(op: ops.ValueOp, compiler: Any) -> dict: + return freeze( + { + "op": type(op).__name__, + "type": _translate_type(op.dtype), + "args": [ + translate_to_yaml(arg, compiler) + for arg in op.args + if isinstance(arg, (ops.Value, ops.Node)) + ], + } + ) + + +@register_from_yaml_handler("ValueOp") +def _value_op_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] + method_name = yaml_dict["op"].lower() + method = getattr(args[0], method_name) + return method(*args[1:]) + + +@translate_to_yaml.register(ops.StringUnary) +def _string_unary_to_yaml(op: ops.StringUnary, compiler: Any) -> dict: + return freeze( + { + "op": type(op).__name__, + "args": [translate_to_yaml(op.arg, compiler)], + "type": _translate_type(op.dtype), + } + ) + + +@register_from_yaml_handler("StringUnary") +def _string_unary_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], compiler) + method_name = yaml_dict["op"].lower() + return getattr(arg, method_name)() + + +@translate_to_yaml.register(ops.Substring) +def _substring_to_yaml(op: ops.Substring, compiler: Any) -> dict: + args = [ + translate_to_yaml(op.arg, compiler), + translate_to_yaml(op.start, compiler), + ] + if op.length is not None: + args.append(translate_to_yaml(op.length, compiler)) + return freeze({"op": "Substring", "args": args, "type": _translate_type(op.dtype)}) + + +@register_from_yaml_handler("Substring") +def _substring_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] + return args[0].substr(args[1], args[2] if len(args) > 2 else None) 
+ + +@translate_to_yaml.register(ops.StringLength) +def _string_length_to_yaml(op: ops.StringLength, compiler: Any) -> dict: + return freeze( + { + "op": "StringLength", + "args": [translate_to_yaml(op.arg, compiler)], + "type": _translate_type(op.dtype), + } + ) + + +@register_from_yaml_handler("StringLength") +def _string_length_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], compiler) + return arg.length() + + +@translate_to_yaml.register(ops.StringConcat) +def _string_concat_to_yaml(op: ops.StringConcat, compiler: Any) -> dict: + return freeze( + { + "op": "StringConcat", + "args": [translate_to_yaml(arg, compiler) for arg in op.arg], + "type": _translate_type(op.dtype), + } + ) + + +@register_from_yaml_handler("StringConcat") +def _string_concat_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] + return functools.functools.reduce(lambda x, y: x.concat(y), args) + + +@translate_to_yaml.register(ops.BinaryOp) +def _binary_op_to_yaml(op: ops.BinaryOp, compiler: Any) -> dict: + return freeze( + { + "op": type(op).__name__, + "args": [ + translate_to_yaml(op.left, compiler), + translate_to_yaml(op.right, compiler), + ], + "type": _translate_type(op.dtype), + } + ) + + +@register_from_yaml_handler("BinaryOp") +def _binary_op_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] + op_name = yaml_dict["op"].lower() + return getattr(args[0], op_name)(args[1]) + + +@translate_to_yaml.register(ops.Filter) +def _filter_to_yaml(op: ops.Filter, compiler: Any) -> dict: + return freeze( + { + "op": "Filter", + "parent": translate_to_yaml(op.parent, compiler), + "predicates": [translate_to_yaml(pred, compiler) for pred in op.predicates], + } + ) + + +@register_from_yaml_handler("Filter") +def _filter_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + parent = 
translate_from_yaml(yaml_dict["parent"], compiler) + predicates = [ + translate_from_yaml(pred, compiler) for pred in yaml_dict["predicates"] + ] + filter_op = ops.Filter(parent, predicates) + return filter_op.to_expr() + + +@translate_to_yaml.register(ops.Project) +def _project_to_yaml(op: ops.Project, compiler: Any) -> dict: + return freeze( + { + "op": "Project", + "parent": translate_to_yaml(op.parent, compiler), + "values": { + name: translate_to_yaml(val, compiler) + for name, val in op.values.items() + }, + } + ) + + +@register_from_yaml_handler("Project") +def _project_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + parent = translate_from_yaml(yaml_dict["parent"], compiler) + values = { + name: translate_from_yaml(val, compiler) + for name, val in yaml_dict["values"].items() + } + projected = parent.projection(values) + return projected + + +@translate_to_yaml.register(ops.Aggregate) +def _aggregate_to_yaml(op: ops.Aggregate, compiler: Any) -> dict: + return freeze( + { + "op": "Aggregate", + "parent": translate_to_yaml(op.parent, compiler), + "by": [translate_to_yaml(group, compiler) for group in op.groups.values()], + "metrics": { + name: translate_to_yaml(metric, compiler) + for name, metric in op.metrics.items() + }, + } + ) + + +@register_from_yaml_handler("Aggregate") +def _aggregate_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + parent = translate_from_yaml(yaml_dict["parent"], compiler) + groups = tuple( + translate_from_yaml(group, compiler) for group in yaml_dict.get("by", []) + ) + + raw_metrics = { + name: translate_from_yaml(metric, compiler) + for name, metric in yaml_dict.get("metrics", {}).items() + } + metrics = raw_metrics + + if groups: + return parent.group_by(list(groups)).aggregate(metrics) + else: + return parent.aggregate(metrics) + + +@translate_to_yaml.register(ops.JoinChain) +def _join_to_yaml(op: ops.JoinChain, compiler: Any) -> dict: + result = { + "op": "JoinChain", + "first": translate_to_yaml(op.first, 
compiler), + "rest": [ + { + "how": link.how, + "table": translate_to_yaml(link.table, compiler), + "predicates": [ + translate_to_yaml(pred, compiler) for pred in link.predicates + ], + } + for link in op.rest + ], + } + if hasattr(op, "values") and op.values: + result["values"] = { + name: translate_to_yaml(val, compiler) for name, val in op.values.items() + } + return freeze(result) + + +@register_from_yaml_handler("JoinChain") +def _join_chain_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + first = translate_from_yaml(yaml_dict["first"], compiler) + result = first + + for join in yaml_dict["rest"]: + table = translate_from_yaml(join["table"], compiler) + predicates = [ + translate_from_yaml(pred, compiler) for pred in join["predicates"] + ] + result = result.join(table, predicates, how=join["how"]) + + if "values" in yaml_dict: + values = { + name: translate_from_yaml(val, compiler) + for name, val in yaml_dict["values"].items() + } + result = result.select(values) + return result + + +@translate_to_yaml.register(ops.Sort) +def _sort_to_yaml(op: ops.Sort, compiler: Any) -> dict: + return freeze( + { + "op": "Sort", + "parent": translate_to_yaml(op.parent, compiler), + "keys": [translate_to_yaml(key, compiler) for key in op.keys], + } + ) + + +@register_from_yaml_handler("Sort") +def _sort_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + parent = translate_from_yaml(yaml_dict["parent"], compiler) + keys = tuple(translate_from_yaml(key, compiler) for key in yaml_dict["keys"]) + sort_op = ops.Sort(parent, keys=keys) + return sort_op.to_expr() + + +@translate_to_yaml.register(ops.SortKey) +def _sort_key_to_yaml(op: ops.SortKey, compiler: Any) -> dict: + return freeze( + { + "op": "SortKey", + "arg": translate_to_yaml(op.expr, compiler), + "ascending": op.ascending, + "nulls_first": op.nulls_first, + "type": _translate_type(op.dtype), + } + ) + + +@register_from_yaml_handler("SortKey") +def _sort_key_from_yaml(yaml_dict: dict, compiler: Any) -> 
ir.Expr: + expr = translate_from_yaml(yaml_dict["arg"], compiler) + ascending = yaml_dict.get("ascending", True) + nulls_first = yaml_dict.get("nulls_first", False) + return ops.SortKey(expr, ascending=ascending, nulls_first=nulls_first).to_expr() + + +@translate_to_yaml.register(ops.Limit) +def _limit_to_yaml(op: ops.Limit, compiler: Any) -> dict: + return freeze( + { + "op": "Limit", + "parent": translate_to_yaml(op.parent, compiler), + "n": op.n, + "offset": op.offset, + } + ) + + +@register_from_yaml_handler("Limit") +def _limit_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + parent = translate_from_yaml(yaml_dict["parent"], compiler) + return parent.limit(yaml_dict["n"], offset=yaml_dict["offset"]) + + +@translate_to_yaml.register(ops.ScalarSubquery) +def _scalar_subquery_to_yaml(op: ops.ScalarSubquery, compiler: Any) -> dict: + return freeze( + { + "op": "ScalarSubquery", + "args": [translate_to_yaml(arg, compiler) for arg in op.args], + "type": _translate_type(op.dtype), + } + ) + + +@register_from_yaml_handler("ScalarSubquery") +def _scalar_subquery_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + subquery = translate_from_yaml(yaml_dict["args"][0], compiler) + return ops.ScalarSubquery(subquery).to_expr() + + +@translate_to_yaml.register(ops.ExistsSubquery) +def _exists_subquery_to_yaml(op: ops.ExistsSubquery, compiler: Any) -> dict: + return freeze( + { + "op": "ExistsSubquery", + "rel": translate_to_yaml(op.rel, compiler), + "type": _translate_type(op.dtype), + } + ) + + +@register_from_yaml_handler("ExistsSubquery") +def _exists_subquery_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + rel = translate_from_yaml(yaml_dict["rel"], compiler) + return ops.ExistsSubquery(rel).to_expr() + + +@translate_to_yaml.register(ops.InSubquery) +def _in_subquery_to_yaml(op: ops.InSubquery, compiler: Any) -> dict: + return freeze( + { + "op": "InSubquery", + "needle": translate_to_yaml(op.needle, compiler), + "haystack": translate_to_yaml(op.rel, 
compiler), + "type": _translate_type(op.dtype), + } + ) + + +@register_from_yaml_handler("InSubquery") +def _in_subquery_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + needle = translate_from_yaml(yaml_dict["needle"], compiler) + haystack = translate_from_yaml(yaml_dict["haystack"], compiler) + return ops.InSubquery(haystack, needle).to_expr() + + +@translate_to_yaml.register(ops.Field) +def _field_to_yaml(op: ops.Field, compiler: Any) -> dict: + result = { + "op": "Field", + "name": op.name, + "relation": translate_to_yaml(op.rel, compiler), + "type": _translate_type(op.dtype), + } + if op.args and len(op.args) >= 2 and isinstance(op.args[1], str): + underlying_name = op.args[1] + if underlying_name != op.name: + result["original_name"] = underlying_name + return freeze(result) + + +@register_from_yaml_handler("Field") +def field_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + relation = translate_from_yaml(yaml_dict["relation"], compiler) + + target_name = yaml_dict["name"] + source_name = yaml_dict.get("original_name", target_name) + + schema = relation.schema() if callable(relation.schema) else relation.schema + + if source_name not in schema.names: + if target_name in schema.names: + source_name = target_name + else: + columns_formatted = ", ".join(schema.names) + raise IbisTypeError( + f"Column {source_name!r} not found in table. " + f"Existing columns: {columns_formatted}." 
+ ) + field = relation[source_name] + + if target_name != source_name: + field = field.name(target_name) + + return freeze(field) + + +@translate_to_yaml.register(ops.InValues) +def _in_values_to_yaml(op: ops.InValues, compiler: Any) -> dict: + return freeze( + { + "op": "InValues", + "args": [ + translate_to_yaml(op.value, compiler), + *[translate_to_yaml(opt, compiler) for opt in op.options], + ], + "type": _translate_type(op.dtype), + } + ) + + +@register_from_yaml_handler("InValues") +def _in_values_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + value = translate_from_yaml(yaml_dict["args"][0], compiler) + options = tuple(translate_from_yaml(opt, compiler) for opt in yaml_dict["args"][1:]) + return ops.InValues(value, options).to_expr() + + +@translate_to_yaml.register(ops.SimpleCase) +def _simple_case_to_yaml(op: ops.SimpleCase, compiler: Any) -> dict: + return freeze( + { + "op": "SimpleCase", + "base": translate_to_yaml(op.base, compiler), + "cases": [translate_to_yaml(case, compiler) for case in op.cases], + "results": [translate_to_yaml(result, compiler) for result in op.results], + "default": translate_to_yaml(op.default, compiler), + "type": _translate_type(op.dtype), + } + ) + + +@register_from_yaml_handler("SimpleCase") +def _simple_case_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + base = translate_from_yaml(yaml_dict["base"], compiler) + cases = tuple(translate_from_yaml(case, compiler) for case in yaml_dict["cases"]) + results = tuple( + translate_from_yaml(result, compiler) for result in yaml_dict["results"] + ) + default = translate_from_yaml(yaml_dict["default"], compiler) + return ops.SimpleCase(base, cases, results, default).to_expr() + + +@translate_to_yaml.register(ops.IfElse) +def _if_else_to_yaml(op: ops.IfElse, compiler: Any) -> dict: + return freeze( + { + "op": "IfElse", + "bool_expr": translate_to_yaml(op.bool_expr, compiler), + "true_expr": translate_to_yaml(op.true_expr, compiler), + "false_null_expr": 
translate_to_yaml(op.false_null_expr, compiler),
+            "type": _translate_type(op.dtype),
+        }
+    )
+
+
+@register_from_yaml_handler("IfElse")
+def _if_else_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr:
+    bool_expr = translate_from_yaml(yaml_dict["bool_expr"], compiler)
+    true_expr = translate_from_yaml(yaml_dict["true_expr"], compiler)
+    false_null_expr = translate_from_yaml(yaml_dict["false_null_expr"], compiler)
+    return ops.IfElse(bool_expr, true_expr, false_null_expr).to_expr()
+
+
+@translate_to_yaml.register(ops.CountDistinct)
+def _count_distinct_to_yaml(op: ops.CountDistinct, compiler: Any) -> dict:
+    return freeze(
+        {
+            "op": "CountDistinct",
+            "args": [translate_to_yaml(op.arg, compiler)],
+            "type": _translate_type(op.dtype),
+        }
+    )
+
+
+@register_from_yaml_handler("CountDistinct")
+def _count_distinct_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr:
+    arg = translate_from_yaml(yaml_dict["args"][0], compiler)
+    return arg.nunique()
+
+
+@translate_to_yaml.register(ops.SelfReference)
+def _self_reference_to_yaml(op: ops.SelfReference, compiler: Any) -> dict:
+    result = {"op": "SelfReference", "identifier": op.identifier}
+    if op.args:
+        result["args"] = [translate_to_yaml(op.args[0], compiler)]
+    return freeze(result)
+
+
+@register_from_yaml_handler("SelfReference")
+def _self_reference_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr:
+    if "args" in yaml_dict and yaml_dict["args"]:
+        underlying = translate_from_yaml(yaml_dict["args"][0], compiler)
+    else:
+        # No serialized child relation: fall back to the relation currently
+        # being translated by the compiler (set during compile_from_yaml).
+        underlying = getattr(compiler, "current_relation", None)
+        if underlying is None:
+            raise NotImplementedError("No relation available for SelfReference")
+
+    identifier = yaml_dict.get("identifier", 0)
+    ref = ops.SelfReference(underlying, identifier=identifier)
+
+    return ref.to_expr()
+
+
+@translate_to_yaml.register(ops.DropColumns)
+def _drop_columns_to_yaml(op: ops.DropColumns, compiler: Any) -> dict:
+    return freeze(
+        {
+            "op": "DropColumns",
+            "parent": translate_to_yaml(op.parent, compiler),
+            "columns_to_drop": list(op.columns_to_drop),
+ } + ) + + +@register_from_yaml_handler("DropColumns") +def _drop_columns_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + parent = translate_from_yaml(yaml_dict["parent"], compiler) + columns = frozenset(yaml_dict["columns_to_drop"]) + op = ops.DropColumns(parent, columns) + return op.to_expr() + + +@translate_to_yaml.register(ops.SearchedCase) +def _searched_case_to_yaml(op: ops.SearchedCase, compiler: Any) -> dict: + return freeze( + { + "op": "SearchedCase", + "cases": [translate_to_yaml(case, compiler) for case in op.cases], + "results": [translate_to_yaml(result, compiler) for result in op.results], + "default": translate_to_yaml(op.default, compiler), + "dtype": _translate_type(op.dtype), + } + ) + + +@register_from_yaml_handler("SearchedCase") +def _searched_case_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + cases = [translate_from_yaml(case, compiler) for case in yaml_dict["cases"]] + results = [translate_from_yaml(result, compiler) for result in yaml_dict["results"]] + default = translate_from_yaml(yaml_dict["default"], compiler) + op = ops.SearchedCase(cases, results, default) + return op.to_expr() + + +@translate_to_yaml.register(ops.ScalarUDF) +def _scalar_udf_to_yaml(op: ops.ScalarUDF, compiler: Any) -> dict: + arg_names = [ + name + for name in dir(op) + if not name.startswith("__") and name not in op.__class__.__slots__ + ] + + return freeze( + { + "op": "ScalarUDF", + "unique_name": op.__func_name__, + "input_type": "builtin", + "args": [translate_to_yaml(arg, compiler) for arg in op.args], + "type": _translate_type(op.dtype), + "pickle": serialize_udf_function(op.__func__), + "module": op.__module__, + "class_name": op.__class__.__name__, + "arg_names": arg_names, + } + ) + + +@register_from_yaml_handler("ScalarUDF") +def _scalar_udf_from_yaml(yaml_dict: dict, compiler: any) -> any: + encoded_fn = yaml_dict.get("pickle") + if not encoded_fn: + raise ValueError("Missing pickle data for ScalarUDF") + fn = 
deserialize_udf_function(encoded_fn) + + args = tuple( + translate_from_yaml(arg, compiler) for arg in yaml_dict.get("args", []) + ) + if not args: + raise ValueError("ScalarUDF requires at least one argument") + + arg_names = yaml_dict.get("arg_names", [f"arg{i}" for i in range(len(args))]) + + fields = { + name: Argument(pattern=rlz.ValueOf(arg.type()), typehint=arg.type()) + for name, arg in zip(arg_names, args) + } + + bases = (ops.ScalarUDF,) + meta = { + "dtype": dt.dtype(yaml_dict["type"]["name"]), + "__input_type__": ops.udf.InputType.BUILTIN, + "__func__": property(fget=lambda _, f=fn: f), + "__config__": {"volatility": "immutable"}, + "__udf_namespace__": None, + "__module__": yaml_dict.get("module", "__main__"), + "__func_name__": yaml_dict["unique_name"], + } + + kwds = {**fields, **meta} + class_name = yaml_dict.get("class_name", yaml_dict["unique_name"]) + + node = type( + class_name, + bases, + kwds, + ) + + return node(*args).to_expr() + + +@register_from_yaml_handler("View") +def _view_from_yaml(yaml_dict: dict, compiler: any) -> ir.Expr: + underlying = translate_from_yaml(yaml_dict["args"][0], compiler) + alias = yaml_dict.get("name") + if alias: + return underlying.alias(alias) + return underlying + + +@register_from_yaml_handler("Mean") +def _mean_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] + return args[0].mean() + + +@register_from_yaml_handler("Add", "Subtract", "Multiply", "Divide") +def _binary_arithmetic_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + left = translate_from_yaml(yaml_dict["args"][0], compiler) + right = translate_from_yaml(yaml_dict["args"][1], compiler) + op_map = { + "Add": lambda left, right: left + right, + "Subtract": lambda left, right: left - right, + "Multiply": lambda left, right: left * right, + "Divide": lambda left, right: left / right, + } + op_func = op_map.get(yaml_dict["op"]) + if op_func is None: + raise 
ValueError(f"Unsupported arithmetic operation: {yaml_dict['op']}") + return op_func(left, right) + + +@register_from_yaml_handler("Sum") +def _sum_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] + return args[0].sum() + + +@register_from_yaml_handler("Min") +def _min_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] + return args[0].min() + + +@register_from_yaml_handler("Max") +def _max_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] + return args[0].max() + + +@register_from_yaml_handler("Abs") +def _abs_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], compiler) + return arg.abs() + + +@register_from_yaml_handler("Count") +def _count_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], compiler) + return arg.count() + + +@register_from_yaml_handler("JoinReference") +def _join_reference_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + table_yaml = yaml_dict["args"][0] + return translate_from_yaml(table_yaml, compiler) + + +@register_from_yaml_handler( + "Equals", "NotEquals", "GreaterThan", "GreaterEqual", "LessThan", "LessEqual" +) +def _binary_compare_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + left = translate_from_yaml(yaml_dict["args"][0], compiler) + right = translate_from_yaml(yaml_dict["args"][1], compiler) + + op_map = { + "Equals": lambda left, right: left == right, + "NotEquals": lambda left, right: left != right, + "GreaterThan": lambda left, right: left > right, + "GreaterEqual": lambda left, right: left >= right, + "LessThan": lambda left, right: left < right, + "LessEqual": lambda left, right: left <= right, + } + + op_func = op_map.get(yaml_dict["op"]) + if op_func is None: + raise 
ValueError(f"Unsupported comparison operation: {yaml_dict['op']}") + return op_func(left, right) + + +@register_from_yaml_handler("Between") +def _between_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] + return args[0].between(args[1], args[2]) + + +@register_from_yaml_handler("Greater", "Less") +def _boolean_ops_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] + op_name = yaml_dict["op"] + op_map = { + "Greater": lambda left, right: left > right, + "Less": lambda left, right: left < right, + } + return op_map[op_name](*args) + + +@register_from_yaml_handler("And") +def _boolean_and_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + args = [translate_from_yaml(arg, compiler) for arg in yaml_dict.get("args", [])] + if not args: + raise ValueError("And operator requires at least one argument") + return functools.reduce(lambda x, y: x & y, args) + + +@register_from_yaml_handler("Or") +def _boolean_or_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + args = [translate_from_yaml(arg, compiler) for arg in yaml_dict.get("args", [])] + if not args: + raise ValueError("Or operator requires at least one argument") + return functools.reduce(lambda x, y: x | y, args) + + +@register_from_yaml_handler("Not") +def _not_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], compiler) + return ~arg + + +@register_from_yaml_handler("IsNull") +def _is_null_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], compiler) + return arg.isnull() + + +@register_from_yaml_handler( + "ExtractYear", + "ExtractMonth", + "ExtractDay", + "ExtractHour", + "ExtractMinute", + "ExtractSecond", +) +def _extract_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], compiler) + op_map = 
{ + "ExtractYear": lambda x: x.year(), + "ExtractMonth": lambda x: x.month(), + "ExtractDay": lambda x: x.day(), + "ExtractHour": lambda x: x.hour(), + "ExtractMinute": lambda x: x.minute(), + "ExtractSecond": lambda x: x.second(), + } + return op_map[yaml_dict["op"]](arg) + + +@register_from_yaml_handler("TimestampDiff") +def _timestamp_diff_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + left = translate_from_yaml(yaml_dict["args"][0], compiler) + right = translate_from_yaml(yaml_dict["args"][1], compiler) + return left - right + + +@register_from_yaml_handler("TimestampAdd", "TimestampSub") +def _timestamp_arithmetic_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + timestamp = translate_from_yaml(yaml_dict["args"][0], compiler) + interval = translate_from_yaml(yaml_dict["args"][1], compiler) + if yaml_dict["op"] == "TimestampAdd": + return timestamp + interval + else: + return timestamp - interval + + +@register_from_yaml_handler("Cast") +def _cast_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], compiler) + target_dtype = _type_from_yaml(yaml_dict["type"]) + return arg.cast(target_dtype) + + +@register_from_yaml_handler("CountStar") +def _count_star_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], compiler) + + return ops.CountStar(arg).to_expr() + + +@register_from_yaml_handler("StringSQLLike") +def _string_sql_like_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + args = yaml_dict.get("args", []) + if not args: + raise ValueError("Missing arguments for StringSQLLike operator") + + col = translate_from_yaml(args[0], compiler) + + if len(args) >= 2: + pattern_expr = translate_from_yaml(args[1], compiler) + else: + pattern_value = args[0].get("value") + if pattern_value is None: + pattern_value = yaml_dict.get("value") + if pattern_value is None: + raise ValueError("Missing pattern for StringSQLLike operator") + pattern_expr = 
ibis.literal(pattern_value, type=dt.String()) + + escape = yaml_dict.get("escape") + + return ops.StringSQLLike(col, pattern_expr, escape=escape).to_expr() + + +def _type_from_yaml(yaml_dict: dict) -> dt.DataType: + if isinstance(yaml_dict, str): + raise ValueError( + f"Unexpected string value '{yaml_dict}' - type definitions should be dictionaries" + ) + type_name = yaml_dict["name"] + base_type = REVERSE_TYPE_REGISTRY.get(type_name) + if base_type is None: + raise ValueError(f"Unknown type: {type_name}") + if callable(base_type) and not isinstance(base_type, dt.DataType): + base_type = base_type(yaml_dict) + elif ( + "nullable" in yaml_dict + and isinstance(base_type, dt.DataType) + and not isinstance(base_type, (tm.IntervalUnit, dt.Timestamp)) + ): + base_type = base_type.copy(nullable=yaml_dict["nullable"]) + return base_type + + +REVERSE_TYPE_REGISTRY = { + "Int8": dt.Int8(), + "Int16": dt.Int16(), + "Int32": dt.Int32(), + "Int64": dt.Int64(), + "UInt8": dt.UInt8(), + "UInt16": dt.UInt16(), + "UInt32": dt.UInt32(), + "UInt64": dt.UInt64(), + "Float32": dt.Float32(), + "Float64": dt.Float64(), + "String": dt.String(), + "Boolean": dt.Boolean(), + "Date": dt.Date(), + "Time": dt.Time(), + "Binary": dt.Binary(), + "JSON": dt.JSON(), + "Null": dt.null, + "Timestamp": lambda yaml_dict: dt.Timestamp( + nullable=yaml_dict.get("nullable", True) + ), + "Decimal": lambda yaml_dict: dt.Decimal( + precision=yaml_dict.get("precision"), + scale=yaml_dict.get("scale"), + nullable=yaml_dict.get("nullable", True), + ), + "IntervalUnit": lambda yaml_dict: tm.IntervalUnit( + yaml_dict["value"] if isinstance(yaml_dict, dict) else yaml_dict + ), + "Interval": lambda yaml_dict: dt.Interval( + unit=_type_from_yaml(yaml_dict["unit"]), + nullable=yaml_dict.get("nullable", True), + ), + "DateUnit": lambda yaml_dict: tm.DateUnit(yaml_dict["value"]), + "TimeUnit": lambda yaml_dict: tm.TimeUnit(yaml_dict["value"]), + "TimestampUnit": lambda yaml_dict: tm.TimestampUnit(yaml_dict["value"]), 
+ "Array": lambda yaml_dict: dt.Array( + _type_from_yaml(yaml_dict["value_type"]), + nullable=yaml_dict.get("nullable", True), + ), + "Map": lambda yaml_dict: dt.Map( + _type_from_yaml(yaml_dict["key_type"]), + _type_from_yaml(yaml_dict["value_type"]), + nullable=yaml_dict.get("nullable", True), + ), +} diff --git a/python/letsql/ibis_yaml/utils.py b/python/letsql/ibis_yaml/utils.py new file mode 100644 index 00000000..b35a908a --- /dev/null +++ b/python/letsql/ibis_yaml/utils.py @@ -0,0 +1,127 @@ +import base64 +from collections.abc import Mapping, Sequence + +import cloudpickle +from ibis.common.collections import FrozenOrderedDict + + +def serialize_udf_function(fn: callable) -> str: + pickled = cloudpickle.dumps(fn) + encoded = base64.b64encode(pickled).decode("ascii") + return encoded + + +def deserialize_udf_function(encoded_fn: str) -> callable: + pickled = base64.b64decode(encoded_fn) + return cloudpickle.loads(pickled) + + +def freeze(obj): + if isinstance(obj, dict): + return FrozenOrderedDict({k: freeze(v) for k, v in obj.items()}) + elif isinstance(obj, list): + return tuple(freeze(x) for x in obj) + return obj + + +class MissingValue: + def __repr__(self): + return "" + + +MISSING = MissingValue() + + +def deep_diff_objects(obj1, obj2, path="root"): + differences = [] + + if obj1 is not obj2: + differences.append((path, obj1, obj2)) + return differences + + if isinstance(obj1, Mapping): + keys1 = set(obj1.keys()) + keys2 = set(obj2.keys()) + for key in keys1 - keys2: + diff_path = f"{path}.{key}" if path else key + differences.append((diff_path, obj1[key], MISSING)) + for key in keys2 - keys1: + diff_path = f"{path}.{key}" if path else key + differences.append((diff_path, MISSING, obj2[key])) + for key in keys1 & keys2: + diff_path = f"{path}.{key}" if path else key + differences.extend(deep_diff_objects(obj1[key], obj2[key], diff_path)) + return differences + + elif isinstance(obj1, Sequence) and not isinstance(obj1, str): + if len(obj1) != len(obj2): 
+ differences.append((path, obj1, obj2)) + for i, (item1, item2) in enumerate(zip(obj1, obj2)): + diff_path = f"{path}[{i}]" + differences.extend(deep_diff_objects(item1, item2, diff_path)) + return differences + + else: + if obj1 != obj2: + differences.append((path, obj1, obj2)) + return differences + + +def serialize_ibis_expr(expr): + try: + op = expr.op() + except Exception: + return repr(expr) + + serialized = { + "expr_class": expr.__class__.__name__, + "op_class": op.__class__.__name__, + } + + op_attrs = {} + for attr in dir(op): + if attr.startswith("_"): + continue + try: + value = getattr(op, attr) + except Exception: + continue + if callable(value): + continue + op_attrs[attr] = value + if op_attrs: + serialized["op_attrs"] = op_attrs + + if hasattr(op, "args"): + try: + children = op.args + except Exception: + children = None + if children is not None: + if isinstance(children, Sequence) and not isinstance(children, str): + serialized["args"] = [serialize_ibis_expr(child) for child in children] + else: + serialized["args"] = serialize_ibis_expr(children) + return serialized + + +def diff_ibis_exprs(expr1, expr2): + if expr1.equals(expr2): + print("Expressions are equal") + return + + serialized1 = serialize_ibis_expr(expr1) + serialized2 = serialize_ibis_expr(expr2) + + diffs = deep_diff_objects(serialized1, serialized2) + if diffs: + print("Found differences:") + for diff in diffs: + path, val1, val2 = diff + print(f"At {path}:") + print(f" First expression: {val1}") + print(f" Second expression: {val2}") + else: + print("No differences found (unexpectedly).") + + return diffs diff --git a/requirements-dev.txt b/requirements-dev.txt index 680307f8..339c7efa 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -25,7 +25,7 @@ chardet==5.2.0 charset-normalizer==3.4.1 cityhash==0.4.7 ; python_full_version < '4.0' click==8.1.8 -cloudpickle==3.1.1 ; python_full_version < '4.0' +cloudpickle==3.1.1 codespell==2.4.1 colorama==0.4.6 comm==0.2.2 diff 
--git a/uv.lock b/uv.lock index fbfb5bf6..0932705f 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.13'", @@ -1144,6 +1145,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/30/5ef5994b090398f9284d2662f56853e5183ae2cb5d8e3db67e4f4cfea407/humanize-4.12.1-py3-none-any.whl", hash = "sha256:86014ca5c52675dffa1d404491952f1f5bf03b07c175a51891a343daebf01fea", size = 127409 }, ] +[[package]] +name = "hypothesis" +version = "6.126.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a5/8c/8281dd1408dd8374b0ed0528e63fb53a556b3d4f901382f51148345ec9fb/hypothesis-6.126.0.tar.gz", hash = "sha256:648b6215ee0468fa85eaee9dceb5b7766a5861c20ee4801bd904a2c02f1a6c9b", size = 420895 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/fc/8e1749aa79631952bf70e57913626fb0a0556eb2ad3c530c5526f6e5ba13/hypothesis-6.126.0-py3-none-any.whl", hash = "sha256:323c58a773482a2b4ba4e35202560cfcba45e8a8e09e7ffb83c0f9bac5b544da", size = 483657 }, +] + [[package]] name = "identify" version = "2.6.7" @@ -3327,9 +3342,11 @@ dependencies = [ { name = "atpublic" }, { name = "attrs", marker = "python_full_version < '4.0'" }, { name = "cityhash", marker = "python_full_version < '4.0'" }, + { name = "cloudpickle" }, { name = "connectorx", marker = "python_full_version < '4.0'" }, { name = "dask", marker = "python_full_version < '4.0'" }, { name = "geoarrow-types", marker = "python_full_version < '4.0'" }, + { name = "hypothesis" }, { name = "pandas", marker = "python_full_version < '4.0'" }, { name = "parsy" }, { name = "psycopg2-binary", marker = "python_full_version < '4.0'" }, @@ -3340,6 +3357,7 @@ dependencies = [ { name = "python-dateutil" }, { name = "pythran", 
marker = "sys_platform == 'darwin'" }, { name = "pytz" }, + { name = "pyyaml" }, { name = "sqlalchemy", marker = "python_full_version < '4.0'" }, { name = "sqlglot" }, { name = "structlog", marker = "python_full_version < '4.0'" }, @@ -3409,6 +3427,7 @@ requires-dist = [ { name = "atpublic", specifier = ">=5.1" }, { name = "attrs", marker = "python_full_version >= '3.10' and python_full_version < '4.0'", specifier = ">=24.0.0,<26" }, { name = "cityhash", marker = "python_full_version >= '3.10' and python_full_version < '4.0'", specifier = ">=0.4.7,<1" }, + { name = "cloudpickle", specifier = ">=3.1.1" }, { name = "connectorx", marker = "python_full_version >= '3.10' and python_full_version < '4.0'", specifier = ">=0.3.2,<0.5.0" }, { name = "dask", marker = "python_full_version >= '3.10' and python_full_version < '4.0'", specifier = "==2025.1.0" }, { name = "datafusion", marker = "python_full_version >= '3.10' and python_full_version < '4.0' and extra == 'datafusion'", specifier = ">=0.6,<44" }, @@ -3416,6 +3435,7 @@ requires-dist = [ { name = "duckdb", marker = "extra == 'duckdb'", specifier = ">=1.1.3" }, { name = "fsspec", marker = "python_full_version >= '3.10' and python_full_version < '4.0' and extra == 'examples'", specifier = ">=2024.6.1,<2025.2.1" }, { name = "geoarrow-types", marker = "python_full_version >= '3.10' and python_full_version < '4.0'", specifier = ">=0.2,<1" }, + { name = "hypothesis", specifier = ">=6.124.9" }, { name = "pandas", marker = "python_full_version >= '3.10' and python_full_version < '4.0'", specifier = ">=1.5.3,<3" }, { name = "parsy", specifier = ">=2" }, { name = "pins", extras = ["gcs"], marker = "python_full_version >= '3.10' and python_full_version < '4.0' and extra == 'examples'", specifier = ">=0.8.3,<1" }, @@ -3427,6 +3447,7 @@ requires-dist = [ { name = "python-dateutil", specifier = ">=2.8.2" }, { name = "pythran", marker = "sys_platform == 'darwin'", specifier = ">=0.17.0" }, { name = "pytz", specifier = ">=2022.7" }, + 
{ name = "pyyaml", specifier = ">=6.0.2" }, { name = "quickgrove", marker = "extra == 'examples'", specifier = ">=0.1.2" }, { name = "quickgrove", marker = "extra == 'quickgrove'", specifier = ">=0.1.2" }, { name = "snowflake-connector-python", marker = "python_full_version >= '3.10' and python_full_version < '4.0' and extra == 'snowflake'", specifier = ">=3.10.1,<4" }, @@ -3437,6 +3458,7 @@ requires-dist = [ { name = "typing-extensions", specifier = ">=4.3.0" }, { name = "xgboost", marker = "python_full_version >= '3.10' and python_full_version < '4.0' and extra == 'examples'", specifier = ">=1.6.1" }, ] +provides-extras = ["duckdb", "datafusion", "snowflake", "quickgrove", "examples"] [package.metadata.requires-dev] dev = [ From 3e05b855c5092093b672e6e38c74411dfddb05d2 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sat, 8 Feb 2025 16:35:51 -0500 Subject: [PATCH 02/45] feat: add letsql RemoteTable op --- python/letsql/ibis_yaml/compiler.py | 5 +- .../letsql/ibis_yaml/tests/test_letsql_ops.py | 108 ++++++++++++++++ python/letsql/ibis_yaml/translate.py | 120 ++++++++++++++++-- 3 files changed, 221 insertions(+), 12 deletions(-) create mode 100644 python/letsql/ibis_yaml/tests/test_letsql_ops.py diff --git a/python/letsql/ibis_yaml/compiler.py b/python/letsql/ibis_yaml/compiler.py index 6610472b..ea83f1b6 100644 --- a/python/letsql/ibis_yaml/compiler.py +++ b/python/letsql/ibis_yaml/compiler.py @@ -6,10 +6,7 @@ def __init__(self): pass def compile_to_yaml(self, expr): - self.current_relation = None - unbound_expr = expr.unbind() - return translate_to_yaml(unbound_expr.op(), self) + return translate_to_yaml(expr.op(), self) def compile_from_yaml(self, yaml_dict): - self.current_relation = None return translate_from_yaml(yaml_dict, self) diff --git a/python/letsql/ibis_yaml/tests/test_letsql_ops.py b/python/letsql/ibis_yaml/tests/test_letsql_ops.py new file mode 100644 index 00000000..bc3bc258 --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_letsql_ops.py @@ 
-0,0 +1,108 @@ +import pytest + +import letsql as ls +from letsql.expr.relations import into_backend +from letsql.ibis_yaml.compiler import IbisYamlCompiler + + +@pytest.fixture(scope="session") +def duckdb_path(tmp_path_factory): + db_path = tmp_path_factory.mktemp("duckdb") / "test.db" + return str(db_path) + + +@pytest.fixture(scope="session") +def prepare_duckdb_con(duckdb_path): + """Load some test data into the DuckDB file outside the main test.""" + con = ls.duckdb.connect(duckdb_path) + con.profile_name = "my_duckdb" # patch + + con.raw_sql( + """ + CREATE TABLE IF NOT EXISTS mytable ( + id INT, + val VARCHAR + ) + """ + ) + con.raw_sql( + """ + INSERT INTO mytable + SELECT i, 'val' || i::VARCHAR + FROM range(1, 6) t(i) + """ + ) + return con + + +def test_duckdb_database_table_roundtrip(prepare_duckdb_con): + con = prepare_duckdb_con + + profiles = {"my_duckdb": con} + + table_expr = con.table("mytable") # DatabaseTable op + + expr1 = table_expr.mutate(new_val=(table_expr.val + "_extra")) + compiler = IbisYamlCompiler() + compiler.profiles = profiles + + yaml_dict = compiler.compile_to_yaml(expr1) + + print("Serialized YAML:\n", yaml_dict) + + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + + df_original = expr1.execute() + df_roundtrip = roundtrip_expr.execute() + + assert df_original.equals(df_roundtrip), "Roundtrip expression data differs!" 
+ + +def test_memtable(prepare_duckdb_con, tmp_path_factory): + table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) + backend = table._find_backend() + backend.profile_name = "default-duckdb" + expr = table.mutate(new_val=2 * ls._.val) + + profiles = {"default-duckdb": backend} + + compiler = IbisYamlCompiler() + compiler.tmp_path = tmp_path_factory.mktemp("duckdb") + compiler.profiles = profiles + + yaml_dict = compiler.compile_to_yaml(expr) + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + + expr.equals(roundtrip_expr) + + assert expr.execute().equals(roundtrip_expr.execute()) + + +def test_into_backend(prepare_duckdb_con, tmp_path_factory): + table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) + backend = table._find_backend() + backend.profile_name = "default-duckdb" + expr = table.mutate(new_val=2 * ls._.val) + + con2 = ls.connect() + con2.profile_name = "default-let" + con3 = ls.connect() + con3.profile_name = "default-datafusion" + + expr = into_backend(expr, con2, "ls_mem").mutate(x=4 * ls._.val) + expr = into_backend(expr, con3, "df_mem") + + profiles = { + "default-duckdb": backend, + "default-let": con2, + "default-datafusion": con3, + } + + compiler = IbisYamlCompiler() + compiler.tmp_path = tmp_path_factory.mktemp("duckdb") + compiler.profiles = profiles + + yaml_dict = compiler.compile_to_yaml(expr) + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + + assert ls.execute(expr).equals(ls.execute(roundtrip_expr)) diff --git a/python/letsql/ibis_yaml/translate.py b/python/letsql/ibis_yaml/translate.py index 89ebe7ae..cfba74ea 100644 --- a/python/letsql/ibis_yaml/translate.py +++ b/python/letsql/ibis_yaml/translate.py @@ -11,9 +11,12 @@ import ibis.expr.operations.temporal as tm import ibis.expr.rules as rlz import ibis.expr.types as ir +import pyarrow.parquet as pq from ibis.common.annotations import Argument from ibis.common.exceptions import IbisTypeError +import letsql as ls +from 
letsql.expr.relations import RemoteTable, into_backend from letsql.ibis_yaml.utils import ( deserialize_udf_function, freeze, @@ -151,6 +154,11 @@ def _translate_literal_value(value: Any, dtype: dt.DataType) -> Any: return value +@translate_to_yaml.register(ir.Expr) +def _expr_to_yaml(expr: ir.Expr, compiler: any) -> dict: + return translate_to_yaml(expr.op(), compiler) + + @translate_to_yaml.register(ops.WindowFunction) def _window_function_to_yaml(op: ops.WindowFunction, compiler: Any) -> dict: result = { @@ -249,20 +257,109 @@ def _unbound_table_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: return ibis.table(schema, name=table_name) +@translate_to_yaml.register(ops.DatabaseTable) +def _database_table_to_yaml(op: ops.DatabaseTable, compiler: Any) -> dict: + profile_name = getattr(op.source, "profile_name", None) + return freeze( + { + "op": "DatabaseTable", + "table": op.name, + "schema": { + name: _translate_type(dtype) for name, dtype in op.schema.items() + }, + "profile": profile_name, + } + ) + + +@register_from_yaml_handler("DatabaseTable") +def _database_table_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: + profile_name = yaml_dict.get("profile") + table_name = yaml_dict.get("table") + if not profile_name or not table_name: + raise ValueError( + "Missing 'profile' or 'table' information in YAML for DatabaseTable." 
+ ) + + try: + con = compiler.profiles[profile_name] + except KeyError: + raise ValueError(f"Profile {profile_name!r} not found in compiler.profiles") + + return con.table(table_name) + + @translate_to_yaml.register(ops.InMemoryTable) def _memtable_to_yaml(op: ops.InMemoryTable, compiler: Any) -> dict: - if not hasattr(compiler, "table_data"): - compiler.table_data = {} - compiler.table_data[id(op)] = op.data + if not hasattr(compiler, "tmp_path"): + raise ValueError( + "Compiler is missing the 'tmp_path' attribute for memtable serialization" + ) + + arrow_table = op.data.to_pyarrow(op.schema) + + file_path = compiler.tmp_path / f"memtable_{id(op)}.parquet" + pq.write_table(arrow_table, str(file_path)) - return _unbound_table_to_yaml( - ops.UnboundTable(name=op.name, schema=op.schema), compiler + return freeze( + { + "op": "InMemoryTable", + "table": op.name, + "schema": { + name: _translate_type(dtype) for name, dtype in op.schema.items() + }, + "file": str(file_path), + } ) @register_from_yaml_handler("InMemoryTable") -def _memtable_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - return _unbound_table_from_yaml(yaml_dict, compiler) +def _memtable_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: + file_path = yaml_dict["file"] + arrow_table = pq.read_table(file_path) + df = arrow_table.to_pandas() + + table_name = yaml_dict.get("table", "memtable") + + memtable_expr = ls.memtable(df, columns=list(df.columns), name=table_name) + return memtable_expr + + +@translate_to_yaml.register(RemoteTable) +def _remotetable_to_yaml(op: RemoteTable, compiler: any) -> dict: + profile_name = getattr(op.source, "profile_name", None) + remote_expr_yaml = translate_to_yaml(op.remote_expr, compiler) + return freeze( + { + "op": "RemoteTable", # use a distinct op key + "table": op.name, # the table’s name (e.g. 
"ls_mem") + "schema": { + name: _translate_type(dtype) for name, dtype in op.schema.items() + }, + "profile": profile_name, # which connection to use on restore + "remote_expr": remote_expr_yaml, # the remote expression that was “injected” + } + ) + + +@register_from_yaml_handler("RemoteTable") +def _remotetable_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: + profile_name = yaml_dict.get("profile") + table_name = yaml_dict.get("table") + remote_expr_yaml = yaml_dict.get("remote_expr") + if not profile_name or not table_name or remote_expr_yaml is None: + raise ValueError( + "Missing keys in RemoteTable YAML; ensure 'profile', 'table', and 'remote_expr' are present." + ) + try: + con = compiler.profiles[profile_name] + except KeyError: + raise ValueError(f"Profile {profile_name!r} not found in compiler.profiles") + + remote_expr = translate_from_yaml(remote_expr_yaml, compiler) + + remote_table_expr = into_backend(remote_expr, con, table_name) + return remote_table_expr @translate_to_yaml.register(ops.Literal) @@ -367,7 +464,7 @@ def _string_concat_to_yaml(op: ops.StringConcat, compiler: Any) -> dict: @register_from_yaml_handler("StringConcat") def _string_concat_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] - return functools.functools.reduce(lambda x, y: x.concat(y), args) + return functools.reduce(lambda x, y: x.concat(y), args) @translate_to_yaml.register(ops.BinaryOp) @@ -911,6 +1008,13 @@ def _binary_arithmetic_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: return op_func(left, right) +@register_from_yaml_handler("Repeat") +def _repeat_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.expr.types.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], compiler) + times = translate_from_yaml(yaml_dict["args"][1], compiler) + return ops.Repeat(arg, times).to_expr() + + @register_from_yaml_handler("Sum") def _sum_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: 
args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] From 988562909c77a7c71d839475fa64a419df48e4d7 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sat, 8 Feb 2025 17:02:14 -0500 Subject: [PATCH 03/45] wip: SourceStorage doesnt have profile_name in con --- .../letsql/ibis_yaml/tests/test_letsql_ops.py | 21 ++++- python/letsql/ibis_yaml/tests/test_tpch.py | 11 --- python/letsql/ibis_yaml/translate.py | 91 +++++++++++++++++-- 3 files changed, 105 insertions(+), 18 deletions(-) diff --git a/python/letsql/ibis_yaml/tests/test_letsql_ops.py b/python/letsql/ibis_yaml/tests/test_letsql_ops.py index bc3bc258..7049fd6c 100644 --- a/python/letsql/ibis_yaml/tests/test_letsql_ops.py +++ b/python/letsql/ibis_yaml/tests/test_letsql_ops.py @@ -13,7 +13,6 @@ def duckdb_path(tmp_path_factory): @pytest.fixture(scope="session") def prepare_duckdb_con(duckdb_path): - """Load some test data into the DuckDB file outside the main test.""" con = ls.duckdb.connect(duckdb_path) con.profile_name = "my_duckdb" # patch @@ -106,3 +105,23 @@ def test_into_backend(prepare_duckdb_con, tmp_path_factory): roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert ls.execute(expr).equals(ls.execute(roundtrip_expr)) + + +def test_memtable_cache(prepare_duckdb_con, tmp_path_factory): + table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) + backend = table._find_backend() + backend.profile_name = "default-duckdb" + expr = table.mutate(new_val=2 * ls._.val).cache() + + profiles = {"default-duckdb": backend} + + compiler = IbisYamlCompiler() + compiler.tmp_path = tmp_path_factory.mktemp("duckdb") + compiler.profiles = profiles + + yaml_dict = compiler.compile_to_yaml(expr) + roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + + expr.equals(roundtrip_expr) + + assert expr.execute().equals(roundtrip_expr.execute()) diff --git a/python/letsql/ibis_yaml/tests/test_tpch.py b/python/letsql/ibis_yaml/tests/test_tpch.py index ecab4a92..a8e4d725 100644 --- 
a/python/letsql/ibis_yaml/tests/test_tpch.py +++ b/python/letsql/ibis_yaml/tests/test_tpch.py @@ -1,8 +1,5 @@ import pytest -import letsql.ibis_yaml -import letsql.ibis_yaml.utils - TPC_H = [ "tpc_h01", @@ -32,18 +29,10 @@ @pytest.mark.parametrize("fixture_name", TPC_H) def test_yaml_roundtrip(fixture_name, compiler, request): - compiler = letsql.ibis_yaml.compiler.IbisYamlCompiler() query = request.getfixturevalue(fixture_name) yaml_dict = compiler.compile_to_yaml(query) - print("Original Query:") - print(query) - roundtrip_query = compiler.compile_from_yaml(yaml_dict) - print("Roundtrip Query:") - print(roundtrip_query) - - letsql.ibis_yaml.utils.diff_ibis_exprs(query, roundtrip_query) assert roundtrip_query.equals(query), ( f"Roundtrip expression for {fixture_name} does not match the original." diff --git a/python/letsql/ibis_yaml/translate.py b/python/letsql/ibis_yaml/translate.py index cfba74ea..59ded35d 100644 --- a/python/letsql/ibis_yaml/translate.py +++ b/python/letsql/ibis_yaml/translate.py @@ -3,6 +3,7 @@ import datetime import decimal import functools +import pathlib from typing import Any import ibis @@ -16,7 +17,7 @@ from ibis.common.exceptions import IbisTypeError import letsql as ls -from letsql.expr.relations import RemoteTable, into_backend +from letsql.expr.relations import CachedNode, RemoteTable, into_backend from letsql.ibis_yaml.utils import ( deserialize_udf_function, freeze, @@ -289,6 +290,47 @@ def _database_table_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: return con.table(table_name) +@translate_to_yaml.register(CachedNode) +def _cached_node_to_yaml(op: CachedNode, compiler: any) -> dict: + return freeze( + { + "op": "CachedNode", + "schema": { + name: _translate_type(dtype) for name, dtype in op.schema.items() + }, + "parent": translate_to_yaml(op.parent, compiler), + "source": getattr(op.source, "profile_name", None), + "storage": translate_storage(op.storage, compiler), + "values": dict(op.values), + } + ) + + 
+@register_from_yaml_handler("CachedNode") +def _cached_node_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: + schema = { + name: _type_from_yaml(dtype_yaml) + for name, dtype_yaml in yaml_dict["schema"].items() + } + parent_expr = translate_from_yaml(yaml_dict["parent"], compiler) + profile_name = yaml_dict.get("source") + try: + source = compiler.profiles[profile_name] + except KeyError: + raise ValueError(f"Profile {profile_name!r} not found in compiler.profiles") + storage = load_storage_from_yaml(yaml_dict["storage"], compiler) + values = yaml_dict.get("values", {}) + + op = CachedNode( + schema=schema, + parent=parent_expr.op(), + source=source, + storage=storage, + values=values, + ) + return op.to_expr() + + @translate_to_yaml.register(ops.InMemoryTable) def _memtable_to_yaml(op: ops.InMemoryTable, compiler: Any) -> dict: if not hasattr(compiler, "tmp_path"): @@ -331,13 +373,13 @@ def _remotetable_to_yaml(op: RemoteTable, compiler: any) -> dict: remote_expr_yaml = translate_to_yaml(op.remote_expr, compiler) return freeze( { - "op": "RemoteTable", # use a distinct op key - "table": op.name, # the table’s name (e.g. "ls_mem") + "op": "RemoteTable", + "table": op.name, "schema": { name: _translate_type(dtype) for name, dtype in op.schema.items() }, - "profile": profile_name, # which connection to use on restore - "remote_expr": remote_expr_yaml, # the remote expression that was “injected” + "profile": profile_name, + "remote_expr": remote_expr_yaml, } ) @@ -347,7 +389,7 @@ def _remotetable_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: profile_name = yaml_dict.get("profile") table_name = yaml_dict.get("table") remote_expr_yaml = yaml_dict.get("remote_expr") - if not profile_name or not table_name or remote_expr_yaml is None: + if profile_name is None: raise ValueError( "Missing keys in RemoteTable YAML; ensure 'profile', 'table', and 'remote_expr' are present." 
) @@ -1259,3 +1301,40 @@ def _type_from_yaml(yaml_dict: dict) -> dt.DataType: nullable=yaml_dict.get("nullable", True), ), } + +# === Helper functions for translating cache storage === + + +def translate_storage(storage, compiler: any) -> dict: + from letsql.common.caching import ParquetStorage, SourceStorage + + if isinstance(storage, ParquetStorage): + return {"type": "ParquetStorage", "path": str(storage.path)} + elif isinstance(storage, SourceStorage): + return { + "type": "SourceStorage", + "source": getattr(storage.source, "profile_name", None), + } + else: + raise NotImplementedError(f"Unknown storage type: {type(storage)}") + + +def load_storage_from_yaml(storage_yaml: dict, compiler: any): + from letsql.expr.relations import ParquetStorage, _SourceStorage + + if storage_yaml["type"] == "ParquetStorage": + default_profile = list(compiler.profiles.values())[0] + return ParquetStorage( + source=default_profile, path=pathlib.Path(storage_yaml["path"]) + ) + elif storage_yaml["type"] == "SourceStorage": + source_profile_name = storage_yaml["source"] + try: + source = compiler.profiles[source_profile_name] + except KeyError: + raise ValueError( + f"Source profile {source_profile_name!r} not found in compiler.profiles" + ) + return _SourceStorage(source=source) + else: + raise NotImplementedError(f"Unknown storage type: {storage_yaml['type']}") From a4fb8c8da458c0f2079925e848ba7293ce583520 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sat, 8 Feb 2025 17:02:14 -0500 Subject: [PATCH 04/45] feat: add SourceStorage cache serde --- .../letsql/ibis_yaml/tests/test_letsql_ops.py | 8 +++--- python/letsql/ibis_yaml/translate.py | 6 ++-- python/letsql/ibis_yaml/utils.py | 28 +++++++++++++++++++ 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/python/letsql/ibis_yaml/tests/test_letsql_ops.py b/python/letsql/ibis_yaml/tests/test_letsql_ops.py index 7049fd6c..e453b34c 100644 --- a/python/letsql/ibis_yaml/tests/test_letsql_ops.py +++ 
b/python/letsql/ibis_yaml/tests/test_letsql_ops.py @@ -112,8 +112,10 @@ def test_memtable_cache(prepare_duckdb_con, tmp_path_factory): backend = table._find_backend() backend.profile_name = "default-duckdb" expr = table.mutate(new_val=2 * ls._.val).cache() + backend1 = expr._find_backend() + backend1.profile_name = "default-let" - profiles = {"default-duckdb": backend} + profiles = {"default-duckdb": backend, "default-let": backend1} compiler = IbisYamlCompiler() compiler.tmp_path = tmp_path_factory.mktemp("duckdb") @@ -122,6 +124,4 @@ def test_memtable_cache(prepare_duckdb_con, tmp_path_factory): yaml_dict = compiler.compile_to_yaml(expr) roundtrip_expr = compiler.compile_from_yaml(yaml_dict) - expr.equals(roundtrip_expr) - - assert expr.execute().equals(roundtrip_expr.execute()) + assert ls.execute(expr).equals(ls.execute(roundtrip_expr)) diff --git a/python/letsql/ibis_yaml/translate.py b/python/letsql/ibis_yaml/translate.py index 59ded35d..d63a4261 100644 --- a/python/letsql/ibis_yaml/translate.py +++ b/python/letsql/ibis_yaml/translate.py @@ -21,7 +21,9 @@ from letsql.ibis_yaml.utils import ( deserialize_udf_function, freeze, + load_storage_from_yaml, serialize_udf_function, + translate_storage, ) @@ -319,14 +321,12 @@ def _cached_node_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: except KeyError: raise ValueError(f"Profile {profile_name!r} not found in compiler.profiles") storage = load_storage_from_yaml(yaml_dict["storage"], compiler) - values = yaml_dict.get("values", {}) op = CachedNode( schema=schema, - parent=parent_expr.op(), + parent=parent_expr, source=source, storage=storage, - values=values, ) return op.to_expr() diff --git a/python/letsql/ibis_yaml/utils.py b/python/letsql/ibis_yaml/utils.py index b35a908a..e5e46eea 100644 --- a/python/letsql/ibis_yaml/utils.py +++ b/python/letsql/ibis_yaml/utils.py @@ -4,6 +4,8 @@ import cloudpickle from ibis.common.collections import FrozenOrderedDict +from letsql.common.caching import ParquetStorage, 
SourceStorage + def serialize_udf_function(fn: callable) -> str: pickled = cloudpickle.dumps(fn) @@ -125,3 +127,29 @@ def diff_ibis_exprs(expr1, expr2): print("No differences found (unexpectedly).") return diffs + + +def translate_storage(storage, compiler: any) -> dict: + if isinstance(storage, ParquetStorage): + return {"type": "ParquetStorage", "path": str(storage.path)} + elif isinstance(storage, SourceStorage): + return { + "type": "SourceStorage", + "source": getattr(storage.source, "profile_name", None), + } + else: + raise NotImplementedError(f"Unknown storage type: {type(storage)}") + + +def load_storage_from_yaml(storage_yaml: dict, compiler: any): + if storage_yaml["type"] == "SourceStorage": + source_profile_name = storage_yaml["source"] + try: + source = compiler.profiles[source_profile_name] + except KeyError: + raise ValueError( + f"Source profile {source_profile_name!r} not found in compiler.profiles" + ) + return SourceStorage(source=source) + else: + raise NotImplementedError(f"Unknown storage type: {storage_yaml['type']}") From dc630ec9cb6e0f46a25520e601ff3fe023b3e3ef Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sat, 15 Feb 2025 11:10:58 -0500 Subject: [PATCH 05/45] feat: add BuildManager --- python/letsql/ibis_yaml/compiler.py | 82 +++++++++++++++++++ .../letsql/ibis_yaml/tests/test_compiler.py | 41 ++++++++++ 2 files changed, 123 insertions(+) create mode 100644 python/letsql/ibis_yaml/tests/test_compiler.py diff --git a/python/letsql/ibis_yaml/compiler.py b/python/letsql/ibis_yaml/compiler.py index ea83f1b6..e87302f1 100644 --- a/python/letsql/ibis_yaml/compiler.py +++ b/python/letsql/ibis_yaml/compiler.py @@ -1,6 +1,88 @@ +import pathlib +from pathlib import Path +from typing import Any, Dict + +import dask +import yaml + from letsql.ibis_yaml.translate import translate_from_yaml, translate_to_yaml +class StorageHandler: + def __init__(self, root_path: pathlib.Path): + self.root_path = ( + Path(root_path) if not isinstance(root_path, 
Path) else root_path + ) + self.root_path.mkdir(parents=True, exist_ok=True) + + def get_path(self, *parts) -> pathlib.Path: + return self.root_path.joinpath(*parts) + + def ensure_dir(self, *parts) -> pathlib.Path: + path = self.get_path(*parts) + path.mkdir(parents=True, exist_ok=True) + return path + + def write_yaml(self, data: Dict[str, Any], *path_parts) -> pathlib.Path: + path = self.get_path(*path_parts) + path.parent.mkdir(parents=True, exist_ok=True) + + with path.open("w") as f: + yaml.dump( + data, + f, + default_flow_style=False, + sort_keys=False, + ) + return path + + def read_yaml(self, *path_parts) -> Dict[str, Any]: + path = self.get_path(*path_parts) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + with path.open("r") as f: + return yaml.safe_load(f) + + def write_text(self, content: str, *path_parts) -> pathlib.Path: + path = self.get_path(*path_parts) + path.parent.mkdir(parents=True, exist_ok=True) + + with path.open("w") as f: + f.write(content) + return path + + def read_text(self, *path_parts) -> str: + path = self.get_path(*path_parts) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + with path.open("r") as f: + return f.read() + + def exists(self, *path_parts) -> bool: + return self.get_path(*path_parts).exists() + + +class BuildManager: + def __init__(self, storage_path: pathlib.Path): + self.storage = StorageHandler(storage_path) + + def get_expr_hash(self, expr) -> str: + expr_hash = dask.base.tokenize(expr) + return expr_hash[:12] # TODO: make length of hash as a config + + def save_yaml(self, yaml_dict: Dict[str, Any], expr) -> pathlib.Path: + expr_hash = self.get_expr_hash(expr) + return self.storage.write_yaml(yaml_dict, expr_hash, "expr.yaml") + + def load_yaml(self, expr_hash: str) -> Dict[str, Any]: + return self.storage.read_yaml(expr_hash, "expr.yaml") + + def get_build_path(self, expr_hash: str) -> pathlib.Path: + return self.storage.ensure_dir(expr_hash) + + 
class IbisYamlCompiler: def __init__(self): pass diff --git a/python/letsql/ibis_yaml/tests/test_compiler.py b/python/letsql/ibis_yaml/tests/test_compiler.py new file mode 100644 index 00000000..21141f82 --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_compiler.py @@ -0,0 +1,41 @@ +import os + +import pytest + +from letsql.ibis_yaml.compiler import BuildManager + + +@pytest.fixture +def build_dir(tmp_path_factory): + return tmp_path_factory.mktemp("builds") + + +def test_build_manager_expr_hash(t, build_dir): + expected = "c6527994ad9a" + build_manager = BuildManager(build_dir) + result = build_manager.get_expr_hash(t) + assert expected == result + + +def test_build_manager_roundtrip(t, build_dir): + build_manager = BuildManager(build_dir) + expr_hash = "c6527994ad9a" + yaml_dict = {"a": "string"} + build_manager.save_yaml(yaml_dict, t) + + with open(build_dir / expr_hash / "expr.yaml") as f: + out = f.read() + assert out == "a: string\n" + result = build_manager.load_yaml(expr_hash) + assert result == yaml_dict + + +def test_build_manager_paths(t, build_dir): + new_path = build_dir / "new_path" + + assert not os.path.exists(new_path) + build_manager = BuildManager(new_path) + assert os.path.exists(new_path) + + build_manager.get_build_path("hash") + assert os.path.exists(new_path / "hash") From 77430503190fad0025e79b466254032eabd7e5e3 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sat, 15 Feb 2025 12:10:01 -0500 Subject: [PATCH 06/45] refactor: IbisYamlCompiler with state --- python/letsql/ibis_yaml/compiler.py | 57 ++++++++++++++++--- python/letsql/ibis_yaml/tests/conftest.py | 9 ++- .../letsql/ibis_yaml/tests/test_compiler.py | 39 ++++++++++--- .../letsql/ibis_yaml/tests/test_join_chain.py | 3 - .../letsql/ibis_yaml/tests/test_letsql_ops.py | 19 +++---- python/letsql/ibis_yaml/translate.py | 6 +- 6 files changed, 98 insertions(+), 35 deletions(-) diff --git a/python/letsql/ibis_yaml/compiler.py b/python/letsql/ibis_yaml/compiler.py index 
e87302f1..71a2c1ab 100644 --- a/python/letsql/ibis_yaml/compiler.py +++ b/python/letsql/ibis_yaml/compiler.py @@ -3,11 +3,23 @@ from typing import Any, Dict import dask +import ibis.expr.types as ir import yaml +from ibis.common.collections import FrozenOrderedDict from letsql.ibis_yaml.translate import translate_from_yaml, translate_to_yaml +class CleanDictYAMLDumper(yaml.SafeDumper): + def represent_frozenordereddict(self, data): + return self.represent_dict(dict(data)) + + +CleanDictYAMLDumper.add_representer( + FrozenOrderedDict, CleanDictYAMLDumper.represent_frozenordereddict +) + + class StorageHandler: def __init__(self, root_path: pathlib.Path): self.root_path = ( @@ -31,6 +43,7 @@ def write_yaml(self, data: Dict[str, Any], *path_parts) -> pathlib.Path: yaml.dump( data, f, + Dumper=CleanDictYAMLDumper, default_flow_style=False, sort_keys=False, ) @@ -72,8 +85,7 @@ def get_expr_hash(self, expr) -> str: expr_hash = dask.base.tokenize(expr) return expr_hash[:12] # TODO: make length of hash as a config - def save_yaml(self, yaml_dict: Dict[str, Any], expr) -> pathlib.Path: - expr_hash = self.get_expr_hash(expr) + def save_yaml(self, yaml_dict: Dict[str, Any], expr_hash) -> pathlib.Path: return self.storage.write_yaml(yaml_dict, expr_hash, "expr.yaml") def load_yaml(self, expr_hash: str) -> Dict[str, Any]: @@ -84,11 +96,40 @@ def get_build_path(self, expr_hash: str) -> pathlib.Path: class IbisYamlCompiler: - def __init__(self): - pass + def __init__(self, build_dir, build_manager=BuildManager): + self.build_manager = build_manager(build_dir) + self.current_path = None + + def compile(self, expr): + yaml_dict = self.compile_to_yaml(expr) + expr_hash = self.build_manager.get_expr_hash(expr) + self.curent_path = self.build_manager.get_build_path(expr_hash) + self.build_manager.save_yaml(yaml_dict, expr_hash) + + def from_hash(self, expr_hash) -> ir.Expr: + yaml_dict = self.build_manager.load_yaml(expr_hash) + + # this is needed for cache to work with 
ForzenOrderedDict + def convert_to_frozen(d): + if isinstance(d, dict): + items = [] + for k, v in d.items(): + converted_v = convert_to_frozen(v) + if isinstance(converted_v, list): + converted_v = tuple(converted_v) + items.append((k, converted_v)) + return FrozenOrderedDict(items) + elif isinstance(d, list): + return [convert_to_frozen(x) for x in d] + return d + + yaml_dict = convert_to_frozen(yaml_dict) + return self.compile_from_yaml(yaml_dict) + + def compile_from_yaml(self, yaml_dict) -> ir.Expr: + return translate_from_yaml(yaml_dict, self) - def compile_to_yaml(self, expr): + def compile_to_yaml(self, expr) -> Dict: + expr_hash = self.build_manager.get_expr_hash(expr) + self.current_path = self.build_manager.get_build_path(expr_hash) return translate_to_yaml(expr.op(), self) - - def compile_from_yaml(self, yaml_dict): - return translate_from_yaml(yaml_dict, self) diff --git a/python/letsql/ibis_yaml/tests/conftest.py b/python/letsql/ibis_yaml/tests/conftest.py index 082d411e..f2d79d6d 100644 --- a/python/letsql/ibis_yaml/tests/conftest.py +++ b/python/letsql/ibis_yaml/tests/conftest.py @@ -781,7 +781,12 @@ def tpc_h22(customer, orders): @pytest.fixture -def compiler(): +def build_dir(tmp_path_factory): + return tmp_path_factory.mktemp("builds") + + +@pytest.fixture +def compiler(build_dir): from letsql.ibis_yaml.compiler import IbisYamlCompiler - return IbisYamlCompiler() + return IbisYamlCompiler(build_dir) diff --git a/python/letsql/ibis_yaml/tests/test_compiler.py b/python/letsql/ibis_yaml/tests/test_compiler.py index 21141f82..61afc7c8 100644 --- a/python/letsql/ibis_yaml/tests/test_compiler.py +++ b/python/letsql/ibis_yaml/tests/test_compiler.py @@ -1,13 +1,8 @@ import os -import pytest +from ibis.common.collections import FrozenOrderedDict -from letsql.ibis_yaml.compiler import BuildManager - - -@pytest.fixture -def build_dir(tmp_path_factory): - return tmp_path_factory.mktemp("builds") +from letsql.ibis_yaml.compiler import BuildManager, 
IbisYamlCompiler def test_build_manager_expr_hash(t, build_dir): @@ -21,7 +16,7 @@ def test_build_manager_roundtrip(t, build_dir): build_manager = BuildManager(build_dir) expr_hash = "c6527994ad9a" yaml_dict = {"a": "string"} - build_manager.save_yaml(yaml_dict, t) + build_manager.save_yaml(yaml_dict, expr_hash) with open(build_dir / expr_hash / "expr.yaml") as f: out = f.read() @@ -39,3 +34,31 @@ def test_build_manager_paths(t, build_dir): build_manager.get_build_path("hash") assert os.path.exists(new_path / "hash") + + +def test_clean_frozen_dict_yaml(build_dir): + build_manager = BuildManager(build_dir) + data = FrozenOrderedDict( + {"string": "text", "integer": 42, "float": 3.14, "boolean": True, "none": None} + ) + + expected_yaml = """string: text +integer: 42 +float: 3.14 +boolean: true +none: null +""" + out_path = build_manager.save_yaml(data, "hash") + result = out_path.read_text() + + assert expected_yaml == result + + +def test_ibis_compiler(t, build_dir): + compiler = IbisYamlCompiler(build_dir) + compiler.compile(t) + expr_hash = "c6527994ad9a" + + roundtrip_expr = compiler.from_hash(expr_hash) + + assert t.equals(roundtrip_expr) diff --git a/python/letsql/ibis_yaml/tests/test_join_chain.py b/python/letsql/ibis_yaml/tests/test_join_chain.py index 3fffe466..11fc10f5 100644 --- a/python/letsql/ibis_yaml/tests/test_join_chain.py +++ b/python/letsql/ibis_yaml/tests/test_join_chain.py @@ -1,8 +1,6 @@ import ibis import pytest -from letsql.ibis_yaml.compiler import IbisYamlCompiler - @pytest.fixture def orders(): @@ -91,7 +89,6 @@ def test_minimal_joinchain_self_reference( ) ) - compiler = IbisYamlCompiler() yaml_dict = compiler.compile_to_yaml(q) q_roundtrip = compiler.compile_from_yaml(yaml_dict) diff --git a/python/letsql/ibis_yaml/tests/test_letsql_ops.py b/python/letsql/ibis_yaml/tests/test_letsql_ops.py index e453b34c..7b826322 100644 --- a/python/letsql/ibis_yaml/tests/test_letsql_ops.py +++ b/python/letsql/ibis_yaml/tests/test_letsql_ops.py @@ -34,7 
+34,7 @@ def prepare_duckdb_con(duckdb_path): return con -def test_duckdb_database_table_roundtrip(prepare_duckdb_con): +def test_duckdb_database_table_roundtrip(prepare_duckdb_con, build_dir): con = prepare_duckdb_con profiles = {"my_duckdb": con} @@ -42,7 +42,7 @@ def test_duckdb_database_table_roundtrip(prepare_duckdb_con): table_expr = con.table("mytable") # DatabaseTable op expr1 = table_expr.mutate(new_val=(table_expr.val + "_extra")) - compiler = IbisYamlCompiler() + compiler = IbisYamlCompiler(build_dir) compiler.profiles = profiles yaml_dict = compiler.compile_to_yaml(expr1) @@ -57,7 +57,7 @@ def test_duckdb_database_table_roundtrip(prepare_duckdb_con): assert df_original.equals(df_roundtrip), "Roundtrip expression data differs!" -def test_memtable(prepare_duckdb_con, tmp_path_factory): +def test_memtable(prepare_duckdb_con, build_dir): table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) backend = table._find_backend() backend.profile_name = "default-duckdb" @@ -65,8 +65,7 @@ def test_memtable(prepare_duckdb_con, tmp_path_factory): profiles = {"default-duckdb": backend} - compiler = IbisYamlCompiler() - compiler.tmp_path = tmp_path_factory.mktemp("duckdb") + compiler = IbisYamlCompiler(build_dir) compiler.profiles = profiles yaml_dict = compiler.compile_to_yaml(expr) @@ -77,7 +76,7 @@ def test_memtable(prepare_duckdb_con, tmp_path_factory): assert expr.execute().equals(roundtrip_expr.execute()) -def test_into_backend(prepare_duckdb_con, tmp_path_factory): +def test_into_backend(prepare_duckdb_con, build_dir): table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) backend = table._find_backend() backend.profile_name = "default-duckdb" @@ -97,8 +96,7 @@ def test_into_backend(prepare_duckdb_con, tmp_path_factory): "default-datafusion": con3, } - compiler = IbisYamlCompiler() - compiler.tmp_path = tmp_path_factory.mktemp("duckdb") + compiler = IbisYamlCompiler(build_dir) compiler.profiles = profiles yaml_dict = 
compiler.compile_to_yaml(expr) @@ -107,7 +105,7 @@ def test_into_backend(prepare_duckdb_con, tmp_path_factory): assert ls.execute(expr).equals(ls.execute(roundtrip_expr)) -def test_memtable_cache(prepare_duckdb_con, tmp_path_factory): +def test_memtable_cache(prepare_duckdb_con, build_dir): table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) backend = table._find_backend() backend.profile_name = "default-duckdb" @@ -117,8 +115,7 @@ def test_memtable_cache(prepare_duckdb_con, tmp_path_factory): profiles = {"default-duckdb": backend, "default-let": backend1} - compiler = IbisYamlCompiler() - compiler.tmp_path = tmp_path_factory.mktemp("duckdb") + compiler = IbisYamlCompiler(build_dir) compiler.profiles = profiles yaml_dict = compiler.compile_to_yaml(expr) diff --git a/python/letsql/ibis_yaml/translate.py b/python/letsql/ibis_yaml/translate.py index d63a4261..2ada25c0 100644 --- a/python/letsql/ibis_yaml/translate.py +++ b/python/letsql/ibis_yaml/translate.py @@ -333,14 +333,14 @@ def _cached_node_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: @translate_to_yaml.register(ops.InMemoryTable) def _memtable_to_yaml(op: ops.InMemoryTable, compiler: Any) -> dict: - if not hasattr(compiler, "tmp_path"): + if not hasattr(compiler, "current_path"): raise ValueError( - "Compiler is missing the 'tmp_path' attribute for memtable serialization" + "Compiler is missing the 'current_path' attribute for memtable serialization" ) arrow_table = op.data.to_pyarrow(op.schema) - file_path = compiler.tmp_path / f"memtable_{id(op)}.parquet" + file_path = compiler.current_path / f"memtable_{id(op)}.parquet" pq.write_table(arrow_table, str(file_path)) return freeze( From ae7a275d2d51d9ecc4b4c1bd6a892ab8f2084dfd Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sat, 15 Feb 2025 20:04:45 -0500 Subject: [PATCH 07/45] feat: add sql plan generation for yaml serialization --- python/letsql/ibis_yaml/compiler.py | 10 +- python/letsql/ibis_yaml/sql.py | 95 
+++++++++ .../letsql/ibis_yaml/tests/test_compiler.py | 67 +++++- python/letsql/ibis_yaml/tests/test_sql.py | 193 ++++++++++++++++++ 4 files changed, 361 insertions(+), 4 deletions(-) create mode 100644 python/letsql/ibis_yaml/sql.py create mode 100644 python/letsql/ibis_yaml/tests/test_sql.py diff --git a/python/letsql/ibis_yaml/compiler.py b/python/letsql/ibis_yaml/compiler.py index 71a2c1ab..bd8452a8 100644 --- a/python/letsql/ibis_yaml/compiler.py +++ b/python/letsql/ibis_yaml/compiler.py @@ -7,6 +7,7 @@ import yaml from ibis.common.collections import FrozenOrderedDict +from letsql.ibis_yaml.sql import generate_sql_plans from letsql.ibis_yaml.translate import translate_from_yaml, translate_to_yaml @@ -86,7 +87,12 @@ def get_expr_hash(self, expr) -> str: return expr_hash[:12] # TODO: make length of hash as a config def save_yaml(self, yaml_dict: Dict[str, Any], expr_hash) -> pathlib.Path: - return self.storage.write_yaml(yaml_dict, expr_hash, "expr.yaml") + filename = ( + "sql.yaml" + if isinstance(yaml_dict, dict) and "queries" in yaml_dict + else "expr.yaml" + ) + return self.storage.write_yaml(yaml_dict, expr_hash, filename) def load_yaml(self, expr_hash: str) -> Dict[str, Any]: return self.storage.read_yaml(expr_hash, "expr.yaml") @@ -105,6 +111,8 @@ def compile(self, expr): expr_hash = self.build_manager.get_expr_hash(expr) self.curent_path = self.build_manager.get_build_path(expr_hash) self.build_manager.save_yaml(yaml_dict, expr_hash) + plans = generate_sql_plans(expr) + self.build_manager.save_yaml(plans, expr_hash) def from_hash(self, expr_hash) -> ir.Expr: yaml_dict = self.build_manager.load_yaml(expr_hash) diff --git a/python/letsql/ibis_yaml/sql.py b/python/letsql/ibis_yaml/sql.py new file mode 100644 index 00000000..d25b94b7 --- /dev/null +++ b/python/letsql/ibis_yaml/sql.py @@ -0,0 +1,95 @@ +from typing import Any, Dict, TypedDict + +import ibis +import ibis.expr.operations as ops +import ibis.expr.types as ir + +from letsql.expr.relations import 
RemoteTable + + +class QueryInfo(TypedDict): + engine: str + profile_name: str + sql: str + + +class SQLPlans(TypedDict): + queries: Dict[str, QueryInfo] + + +def find_remote_tables(op) -> Dict[str, Dict[str, Any]]: + remote_tables = {} + seen = set() + + def traverse(node): + if node is None or id(node) in seen: + return + + seen.add(id(node)) + + if isinstance(node, ops.Node) and isinstance(node, RemoteTable): + remote_expr = node.remote_expr + original_backend = remote_expr._find_backend() + if ( + not hasattr(original_backend, "profile_name") + or original_backend.profile_name is None + ): + raise AttributeError( + "Backend does not have a valid 'profile_name' attribute." + ) + + engine_name = original_backend.name + profile_name = original_backend.profile_name + remote_tables[node.name] = { + "engine": engine_name, + "profile_name": profile_name, + "sql": ibis.to_sql(remote_expr), + } + + if isinstance(node, ops.Node): + for arg in node.args: + if isinstance(arg, ops.Node): + traverse(arg) + elif isinstance(arg, (list, tuple)): + for item in arg: + if isinstance(item, ops.Node): + traverse(item) + elif isinstance(arg, dict): + for v in arg.values(): + if isinstance(v, ops.Node): + traverse(v) + + traverse(op) + return remote_tables + + +def generate_sql_plans(expr: ir.Expr) -> SQLPlans: + remote_tables = find_remote_tables(expr.op()) + + main_sql = ibis.to_sql(expr) + backend = expr._find_backend() + + if not hasattr(backend, "profile_name") or backend.profile_name is None: + raise AttributeError("Backend does not have a valid 'profile_name' attribute.") + + engine_name = backend.name + profile_name = backend.profile_name + + plans: SQLPlans = { + "queries": { + "main": { + "engine": engine_name, + "profile_name": profile_name, + "sql": main_sql.strip(), + } + } + } + + for table_name, info in remote_tables.items(): + plans["queries"][table_name] = { + "engine": info["engine"], + "profile_name": info["profile_name"], + "sql": info["sql"].strip(), + } + + 
return plans diff --git a/python/letsql/ibis_yaml/tests/test_compiler.py b/python/letsql/ibis_yaml/tests/test_compiler.py index 61afc7c8..cd077563 100644 --- a/python/letsql/ibis_yaml/tests/test_compiler.py +++ b/python/letsql/ibis_yaml/tests/test_compiler.py @@ -1,7 +1,10 @@ import os +import pathlib +import dask from ibis.common.collections import FrozenOrderedDict +import letsql as ls from letsql.ibis_yaml.compiler import BuildManager, IbisYamlCompiler @@ -22,6 +25,8 @@ def test_build_manager_roundtrip(t, build_dir): out = f.read() assert out == "a: string\n" result = build_manager.load_yaml(expr_hash) + + # assert os.path.exists(build_dir/ expr_hash / "sql.yaml") assert result == yaml_dict @@ -55,10 +60,66 @@ def test_clean_frozen_dict_yaml(build_dir): def test_ibis_compiler(t, build_dir): + t = ls.memtable({"a": [0, 1], "b": [0, 1]}) + backend = t._find_backend() + backend.profile_name = "default" + expr = t.filter(t.a == 1).drop("b") compiler = IbisYamlCompiler(build_dir) - compiler.compile(t) - expr_hash = "c6527994ad9a" + compiler.profiles = {"default": backend} + compiler.compile(expr) + expr_hash = dask.base.tokenize(expr)[:12] + + roundtrip_expr = compiler.from_hash(expr_hash) + + assert expr.execute().equals(roundtrip_expr.execute()) + + +def test_ibis_compiler_parquet_reader(t, build_dir): + backend = ls.datafusion.connect() + backend.profile_name = "default" + awards_players = backend.read_parquet( + ls.config.options.pins.get_path("awards_players"), + table_name="awards_players", + ) + expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") + compiler = IbisYamlCompiler(build_dir) + compiler.profiles = {"default": backend} + compiler.compile(expr) + expr_hash = "5ebaf6a7a02d" roundtrip_expr = compiler.from_hash(expr_hash) - assert t.equals(roundtrip_expr) + assert expr.execute().equals(roundtrip_expr.execute()) + + +def test_compiler_sql(build_dir): + backend = ls.datafusion.connect() + backend.profile_name = "default" + 
awards_players = backend.read_parquet( + ls.config.options.pins.get_path("awards_players"), + table_name="awards_players", + ) + expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") + + compiler = IbisYamlCompiler(build_dir) + compiler.profiles = {"default": backend} + compiler.compile(expr) + expr_hash = "5ebaf6a7a02d" + _roundtrip_expr = compiler.from_hash(expr_hash) + + assert os.path.exists(build_dir / expr_hash / "sql.yaml") + + sql_text = pathlib.Path(build_dir / expr_hash / "sql.yaml").read_text() + expected_result = ( + "queries:\n" + " main:\n" + " engine: datafusion\n" + " profile_name: default\n" + ' sql: "SELECT\\n \\"t0\\".\\"playerID\\",\\n ' + '\\"t0\\".\\"awardID\\",\\n \\"t0\\".\\"tie\\"\\\n' + ' ,\\n \\"t0\\".\\"notes\\"\\nFROM \\"awards_players\\" AS ' + '\\"t0\\"\\nWHERE\\n \\"t0\\".\\"\\\n' + " lgID\\\" = 'NL'\"\n" + ) + + assert sql_text == expected_result diff --git a/python/letsql/ibis_yaml/tests/test_sql.py b/python/letsql/ibis_yaml/tests/test_sql.py new file mode 100644 index 00000000..dc735217 --- /dev/null +++ b/python/letsql/ibis_yaml/tests/test_sql.py @@ -0,0 +1,193 @@ +import ibis.expr.operations as ops +import pytest + +import letsql as ls +from letsql.expr.relations import RemoteTable, into_backend +from letsql.ibis_yaml.sql import find_remote_tables, generate_sql_plans + + +def test_find_remote_tables_simple(): + db = ls.duckdb.connect() + db.profile_name = "duckdb" + table = ls.memtable([(1, "a"), (2, "b")], columns=["id", "val"]) + backend = table._find_backend() + backend.profile_name = "duckdb" + remote_expr = into_backend(table, db) + + remote_tables = find_remote_tables(remote_expr.op()) + + assert len(remote_tables) == 1 + table_name = next(iter(remote_tables)) + assert table_name.startswith("ibis_remote") + assert remote_tables[table_name]["engine"] == "duckdb" + + +def test_find_remote_tables_raises(): + db = ls.connect() + + awards_players = db.read_parquet( + 
ls.config.options.pins.get_path("awards_players"), + table_name="awards_players", + ) + + db2 = ls.datafusion.connect() + + remote_expr = into_backend(awards_players, db2) + with pytest.raises( + AttributeError, match="Backend does not have a valid 'profile_name' attribute." + ): + find_remote_tables(remote_expr.op()) + + +def test_find_remote_tables_nested(): + db1 = ls.duckdb.connect() + db1.profile_name = "duckdb" + db2 = ls.datafusion.connect() + db2.profile_name = "datafusion" + + table1 = ls.memtable([(1, "a"), (2, "b")], columns=["id", "val1"]) + table2 = ls.memtable([(1, "x"), (2, "y")], columns=["id", "val2"]) + + remote1 = into_backend(table1, db1) + remote2 = into_backend(table2, db2) + expr = remote1.join(remote2, "id") + + remote_tables = find_remote_tables(expr.op()) + + assert len(remote_tables) == 2 + assert all(name.startswith("ibis_remote") for name in remote_tables) + assert all("engine" in info and "sql" in info for info in remote_tables.values()) + + +def test_find_remote_tables(): + pg = ls.postgres.connect_examples() + pg.profile_name = "postgres" + db = ls.duckdb.connect() + db.profile_name = "duckdb" + + batting = pg.table("batting") + awards_players = db.read_parquet( + ls.config.options.pins.get_path("awards_players"), + table_name="awards_players", + ) + + left = batting.filter(batting.yearID == 2015) + right = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") + expr = left.join(into_backend(right, pg), ["playerID"], how="semi")[ + ["yearID", "stint"] + ] + + def print_tree(node, level=0): + indent = " " * level + print(f"{indent}{type(node).__name__}") + if hasattr(node, "args"): + for arg in node.args: + if isinstance(arg, (ops.Node, RemoteTable)): + print_tree(arg, level + 1) + + print_tree(expr.op()) + + remote_tables = find_remote_tables(expr.op()) + + assert len(remote_tables) == 1, ( + f"Expected 1 remote table, found {len(remote_tables)}" + ) + + first_table = next(iter(remote_tables.values())) + assert 
"sql" in first_table, "SQL query missing from remote table info" + assert "engine" in first_table, "Engine info missing from remote table info" + + +def test_generate_sql_plans_simple(): + db = ls.duckdb.connect() + db.profile_name = "duckdb" + table = ls.memtable([(1, "a"), (2, "b")], columns=["id", "val"]) + expr = into_backend(table, db).filter(ls._.id > 1) + + plans = generate_sql_plans(expr) + + assert "queries" in plans + assert "main" in plans["queries"] + assert len(plans["queries"]) == 2 + assert all("sql" in q and "engine" in q for q in plans["queries"].values()) + + +def test_generate_sql_plans_raises(): + db = ls.duckdb.connect() + table = ls.memtable([(1, "a"), (2, "b")], columns=["id", "val"]) + expr = into_backend(table, db).filter(ls._.id > 1) + with pytest.raises( + AttributeError, match="Backend does not have a valid 'profile_name' attribute." + ): + generate_sql_plans(expr) + + +def test_generate_sql_plans_complex_example(): + pg = ls.postgres.connect_examples() + pg.profile_name = "postgres" + + db = ls.duckdb.connect() + db.profile_name = "duckdb" + + batting = pg.table("batting") + awards_players = db.read_parquet( + ls.config.options.pins.get_path("awards_players"), + table_name="awards_players", + ) + + left = batting.filter(batting.yearID == 2015) + right = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") + expr = left.join(into_backend(right, pg), ["playerID"], how="semi")[ + ["yearID", "stint"] + ] + + plans = generate_sql_plans(expr) + + assert "queries" in plans + assert len(plans["queries"]) == 2 + assert "main" in plans["queries"] + + remote_table_names = [k for k in plans["queries"].keys() if k != "main"] + assert len(remote_table_names) == 1 + remote_table_name = remote_table_names[0] + assert remote_table_name.startswith("ibis_remote") + + expected_main_sql = f'''SELECT + "t4"."yearID", + "t4"."stint" +FROM ( + SELECT + * + FROM "batting" AS "t0" + WHERE + "t0"."yearID" = 2015 +) AS "t4" +WHERE + EXISTS( + 
SELECT + 1 + FROM "{remote_table_name}" AS "t2" + WHERE + "t4"."playerID" = "t2"."playerID" + )''' + + expected_remote_sql = """SELECT + "t0"."playerID", + "t0"."awardID", + "t0"."tie", + "t0"."notes" +FROM "awards_players" AS "t0" +WHERE + "t0"."lgID" = 'NL\'""" + + main_query = plans["queries"]["main"] + assert main_query["engine"] == "postgres", ( + f"Expected 'postgres', got '{main_query['engine']}'" + ) + assert main_query["sql"].strip() == expected_main_sql.strip() + + remote_query = plans["queries"][remote_table_name] + assert remote_query["engine"] == "duckdb", ( + f"Expected 'duckdb', got '{remote_query['engine']}'" + ) + assert remote_query["sql"].strip() == expected_remote_sql.strip() From 27b3be6f5561f8d215643415996f591acb3aa09e Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sun, 16 Feb 2025 12:49:04 -0500 Subject: [PATCH 08/45] feat: add seperate schema definitions in yaml - removes redundant schema definitions in yaml --- python/letsql/ibis_yaml/compiler.py | 23 ++- .../letsql/ibis_yaml/tests/test_arithmetic.py | 63 +++--- python/letsql/ibis_yaml/tests/test_basic.py | 65 +++--- .../tests/test_operations_boolean.py | 63 +++--- .../ibis_yaml/tests/test_operations_cast.py | 32 +-- .../tests/test_operations_datetime.py | 59 +++--- .../letsql/ibis_yaml/tests/test_relations.py | 46 ++--- .../letsql/ibis_yaml/tests/test_string_ops.py | 10 +- .../letsql/ibis_yaml/tests/test_subquery.py | 15 +- .../ibis_yaml/tests/test_window_functions.py | 7 +- python/letsql/ibis_yaml/translate.py | 189 ++++++++++-------- python/letsql/ibis_yaml/utils.py | 11 +- 12 files changed, 328 insertions(+), 255 deletions(-) diff --git a/python/letsql/ibis_yaml/compiler.py b/python/letsql/ibis_yaml/compiler.py index bd8452a8..47810656 100644 --- a/python/letsql/ibis_yaml/compiler.py +++ b/python/letsql/ibis_yaml/compiler.py @@ -8,9 +8,15 @@ from ibis.common.collections import FrozenOrderedDict from letsql.ibis_yaml.sql import generate_sql_plans -from letsql.ibis_yaml.translate 
import translate_from_yaml, translate_to_yaml +from letsql.ibis_yaml.translate import ( + SchemaRegistry, + translate_from_yaml, + translate_to_yaml, +) +from letsql.ibis_yaml.utils import freeze +# is this the right way to handle this? or the right place class CleanDictYAMLDumper(yaml.SafeDumper): def represent_frozenordereddict(self, data): return self.represent_dict(dict(data)) @@ -104,6 +110,7 @@ def get_build_path(self, expr_hash: str) -> pathlib.Path: class IbisYamlCompiler: def __init__(self, build_dir, build_manager=BuildManager): self.build_manager = build_manager(build_dir) + self.schema_registry = SchemaRegistry() self.current_path = None def compile(self, expr): @@ -117,7 +124,6 @@ def compile(self, expr): def from_hash(self, expr_hash) -> ir.Expr: yaml_dict = self.build_manager.load_yaml(expr_hash) - # this is needed for cache to work with ForzenOrderedDict def convert_to_frozen(d): if isinstance(d, dict): items = [] @@ -135,9 +141,18 @@ def convert_to_frozen(d): return self.compile_from_yaml(yaml_dict) def compile_from_yaml(self, yaml_dict) -> ir.Expr: - return translate_from_yaml(yaml_dict, self) + self.definitions = yaml_dict.get("definitions", {}) + return translate_from_yaml(yaml_dict["expression"], self) def compile_to_yaml(self, expr) -> Dict: expr_hash = self.build_manager.get_expr_hash(expr) self.current_path = self.build_manager.get_build_path(expr_hash) - return translate_to_yaml(expr.op(), self) + expr_yaml = translate_to_yaml(expr, self) + + schema_definitions = {} + for schema_id, schema in self.schema_registry.schemas.items(): + schema_definitions[schema_id] = schema + + return freeze( + {"definitions": {"schemas": schema_definitions}, "expression": expr_yaml} + ) diff --git a/python/letsql/ibis_yaml/tests/test_arithmetic.py b/python/letsql/ibis_yaml/tests/test_arithmetic.py index 4722f680..2fc6a11d 100644 --- a/python/letsql/ibis_yaml/tests/test_arithmetic.py +++ b/python/letsql/ibis_yaml/tests/test_arithmetic.py @@ -7,12 +7,13 @@ def 
test_add(compiler): expr = lit1 + lit2 yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Add" - assert yaml_dict["args"][0]["op"] == "Literal" - assert yaml_dict["args"][0]["value"] == 5 - assert yaml_dict["args"][1]["op"] == "Literal" - assert yaml_dict["args"][1]["value"] == 3 - assert yaml_dict["type"] == {"name": "Int8", "nullable": True} + expression = yaml_dict["expression"] + assert expression["op"] == "Add" + assert expression["args"][0]["op"] == "Literal" + assert expression["args"][0]["value"] == 5 + assert expression["args"][1]["op"] == "Literal" + assert expression["args"][1]["value"] == 3 + assert expression["type"] == {"name": "Int8", "nullable": True} roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -24,12 +25,13 @@ def test_subtract(compiler): expr = lit1 - lit2 yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Subtract" - assert yaml_dict["args"][0]["op"] == "Literal" - assert yaml_dict["args"][0]["value"] == 5 - assert yaml_dict["args"][1]["op"] == "Literal" - assert yaml_dict["args"][1]["value"] == 3 - assert yaml_dict["type"] == {"name": "Int8", "nullable": True} + expression = yaml_dict["expression"] + assert expression["op"] == "Subtract" + assert expression["args"][0]["op"] == "Literal" + assert expression["args"][0]["value"] == 5 + assert expression["args"][1]["op"] == "Literal" + assert expression["args"][1]["value"] == 3 + assert expression["type"] == {"name": "Int8", "nullable": True} roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -41,12 +43,13 @@ def test_multiply(compiler): expr = lit1 * lit2 yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Multiply" - assert yaml_dict["args"][0]["op"] == "Literal" - assert yaml_dict["args"][0]["value"] == 5 - assert yaml_dict["args"][1]["op"] == "Literal" - assert yaml_dict["args"][1]["value"] == 3 - assert yaml_dict["type"] == {"name": "Int8", 
"nullable": True} + expression = yaml_dict["expression"] + assert expression["op"] == "Multiply" + assert expression["args"][0]["op"] == "Literal" + assert expression["args"][0]["value"] == 5 + assert expression["args"][1]["op"] == "Literal" + assert expression["args"][1]["value"] == 3 + assert expression["type"] == {"name": "Int8", "nullable": True} roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -58,12 +61,13 @@ def test_divide(compiler): expr = lit1 / lit2 yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Divide" - assert yaml_dict["args"][0]["op"] == "Literal" - assert yaml_dict["args"][0]["value"] == 6.0 - assert yaml_dict["args"][1]["op"] == "Literal" - assert yaml_dict["args"][1]["value"] == 2.0 - assert yaml_dict["type"] == {"name": "Float64", "nullable": True} + expression = yaml_dict["expression"] + assert expression["op"] == "Divide" + assert expression["args"][0]["op"] == "Literal" + assert expression["args"][0]["value"] == 6.0 + assert expression["args"][1]["op"] == "Literal" + assert expression["args"][1]["value"] == 2.0 + assert expression["type"] == {"name": "Float64", "nullable": True} roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -75,8 +79,9 @@ def test_mixed_arithmetic(compiler): expr = i * f yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Multiply" - assert yaml_dict["type"] == {"name": "Float64", "nullable": True} + expression = yaml_dict["expression"] + assert expression["op"] == "Multiply" + assert expression["type"] == {"name": "Float64", "nullable": True} roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -89,8 +94,10 @@ def test_complex_arithmetic(compiler): expr = (a + b) * c yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Multiply" - assert yaml_dict["args"][0]["op"] == "Add" + expression = yaml_dict["expression"] + + assert 
expression["op"] == "Multiply" + assert expression["args"][0]["op"] == "Add" roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_basic.py b/python/letsql/ibis_yaml/tests/test_basic.py index a8b40e59..85f20294 100644 --- a/python/letsql/ibis_yaml/tests/test_basic.py +++ b/python/letsql/ibis_yaml/tests/test_basic.py @@ -6,9 +6,10 @@ def test_unbound_table(t, compiler): yaml_dict = compiler.compile_to_yaml(t) - assert yaml_dict["op"] == "UnboundTable" - assert yaml_dict["name"] == "test_table" - assert yaml_dict["schema"]["a"] == {"name": "Int64", "nullable": True} + expression = yaml_dict["expression"] + assert expression["op"] == "UnboundTable" + assert expression["name"] == "test_table" + assert expression["schema_ref"] roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.schema() == t.schema() @@ -18,9 +19,10 @@ def test_unbound_table(t, compiler): def test_field(t, compiler): expr = t.a yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Field" - assert yaml_dict["name"] == "a" - assert yaml_dict["type"] == {"name": "Int64", "nullable": True} + expression = yaml_dict["expression"] + assert expression["op"] == "Field" + assert expression["name"] == "a" + assert expression["type"] == {"name": "Int64", "nullable": True} roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -30,9 +32,11 @@ def test_field(t, compiler): def test_literal(compiler): lit = ibis.literal(42) yaml_dict = compiler.compile_to_yaml(lit) - assert yaml_dict["op"] == "Literal" - assert yaml_dict["value"] == 42 - assert yaml_dict["type"] == {"name": "Int8", "nullable": True} + + expression = yaml_dict["expression"] + assert expression["op"] == "Literal" + assert expression["value"] == 42 + assert expression["type"] == {"name": "Int8", "nullable": True} roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert 
roundtrip_expr.equals(lit) @@ -41,9 +45,10 @@ def test_literal(compiler): def test_binary_op(t, compiler): expr = t.a + 1 yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Add" - assert yaml_dict["args"][0]["op"] == "Field" - assert yaml_dict["args"][1]["op"] == "Literal" + expression = yaml_dict["expression"] + assert expression["op"] == "Add" + assert expression["args"][0]["op"] == "Field" + assert expression["args"][1]["op"] == "Literal" roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -60,8 +65,10 @@ def test_primitive_types(compiler): ] for lit, expected_type in primitives: yaml_dict = compiler.compile_to_yaml(lit) - assert yaml_dict["op"] == "Literal" - assert yaml_dict["type"]["name"] == expected_type + + expression = yaml_dict["expression"] + assert expression["op"] == "Literal" + assert expression["type"]["name"] == expected_type roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(lit) @@ -79,8 +86,9 @@ def test_temporal_types(compiler): ] for lit, expected_type in temporals: yaml_dict = compiler.compile_to_yaml(lit) - assert yaml_dict["op"] == "Literal" - assert yaml_dict["type"]["name"] == expected_type + expression = yaml_dict["expression"] + assert expression["op"] == "Literal" + assert expression["type"]["name"] == expected_type roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(lit) @@ -91,9 +99,10 @@ def test_decimal_type(compiler): dec = decimal.Decimal("123.45") lit = ibis.literal(dec) yaml_dict = compiler.compile_to_yaml(lit) - assert yaml_dict["op"] == "Literal" - assert yaml_dict["type"]["name"] == "Decimal" - assert yaml_dict["type"]["nullable"] + expression = yaml_dict["expression"] + assert expression["op"] == "Literal" + assert expression["type"]["name"] == "Decimal" + assert expression["type"]["nullable"] roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(lit) @@ -103,10 +112,11 
@@ def test_decimal_type(compiler): def test_array_type(compiler): lit = ibis.literal([1, 2, 3]) yaml_dict = compiler.compile_to_yaml(lit) - assert yaml_dict["op"] == "Literal" - assert yaml_dict["type"]["name"] == "Array" - assert yaml_dict["type"]["value_type"]["name"] == "Int8" - assert yaml_dict["value"] == (1, 2, 3) + expression = yaml_dict["expression"] + assert expression["op"] == "Literal" + assert expression["type"]["name"] == "Array" + assert expression["type"]["value_type"]["name"] == "Int8" + assert expression["value"] == (1, 2, 3) roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(lit) @@ -116,10 +126,11 @@ def test_array_type(compiler): def test_map_type(compiler): lit = ibis.literal({"a": 1, "b": 2}) yaml_dict = compiler.compile_to_yaml(lit) - assert yaml_dict["op"] == "Literal" - assert yaml_dict["type"]["name"] == "Map" - assert yaml_dict["type"]["key_type"]["name"] == "String" - assert yaml_dict["type"]["value_type"]["name"] == "Int8" + expression = yaml_dict["expression"] + assert expression["op"] == "Literal" + assert expression["type"]["name"] == "Map" + assert expression["type"]["key_type"]["name"] == "String" + assert expression["type"]["value_type"]["name"] == "Int8" roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(lit) diff --git a/python/letsql/ibis_yaml/tests/test_operations_boolean.py b/python/letsql/ibis_yaml/tests/test_operations_boolean.py index 2e0bd936..4d0171e6 100644 --- a/python/letsql/ibis_yaml/tests/test_operations_boolean.py +++ b/python/letsql/ibis_yaml/tests/test_operations_boolean.py @@ -6,10 +6,11 @@ def test_equals(compiler): b = ibis.literal(5) expr = a == b yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Equals" - assert yaml_dict["args"][0]["value"] == 5 - assert yaml_dict["args"][1]["value"] == 5 - assert yaml_dict["type"] == {"name": "Boolean", "nullable": True} + expression = yaml_dict["expression"] + assert expression["op"] 
== "Equals" + assert expression["args"][0]["value"] == 5 + assert expression["args"][1]["value"] == 5 + assert expression["type"] == {"name": "Boolean", "nullable": True} roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -19,9 +20,10 @@ def test_not_equals(compiler): b = ibis.literal(3) expr = a != b yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "NotEquals" - assert yaml_dict["args"][0]["value"] == 5 - assert yaml_dict["args"][1]["value"] == 3 + expression = yaml_dict["expression"] + assert expression["op"] == "NotEquals" + assert expression["args"][0]["value"] == 5 + assert expression["args"][1]["value"] == 3 roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -31,9 +33,10 @@ def test_greater_than(compiler): b = ibis.literal(3) expr = a > b yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Greater" - assert yaml_dict["args"][0]["value"] == 5 - assert yaml_dict["args"][1]["value"] == 3 + expression = yaml_dict["expression"] + assert expression["op"] == "Greater" + assert expression["args"][0]["value"] == 5 + assert expression["args"][1]["value"] == 3 roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -43,9 +46,10 @@ def test_less_than(compiler): b = ibis.literal(5) expr = a < b yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Less" - assert yaml_dict["args"][0]["value"] == 3 - assert yaml_dict["args"][1]["value"] == 5 + expression = yaml_dict["expression"] + assert expression["op"] == "Less" + assert expression["args"][0]["value"] == 3 + assert expression["args"][1]["value"] == 5 roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -57,17 +61,19 @@ def test_and_or(compiler): expr_and = (a > b) & (a < c) yaml_dict = compiler.compile_to_yaml(expr_and) - assert yaml_dict["op"] == "And" - assert yaml_dict["args"][0]["op"] == "Greater" - 
assert yaml_dict["args"][1]["op"] == "Less" + expression = yaml_dict["expression"] + assert expression["op"] == "And" + assert expression["args"][0]["op"] == "Greater" + assert expression["args"][1]["op"] == "Less" roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr_and) expr_or = (a > b) | (a < c) yaml_dict = compiler.compile_to_yaml(expr_or) - assert yaml_dict["op"] == "Or" - assert yaml_dict["args"][0]["op"] == "Greater" - assert yaml_dict["args"][1]["op"] == "Less" + expression = yaml_dict["expression"] + assert expression["op"] == "Or" + assert expression["args"][0]["op"] == "Greater" + assert expression["args"][1]["op"] == "Less" roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr_or) @@ -76,8 +82,9 @@ def test_not(compiler): a = ibis.literal(True) expr = ~a yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Not" - assert yaml_dict["args"][0]["value"] + expression = yaml_dict["expression"] + assert expression["op"] == "Not" + assert expression["args"][0]["value"] roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -86,8 +93,9 @@ def test_is_null(compiler): a = ibis.literal(None) expr = a.isnull() yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "IsNull" - assert yaml_dict["args"][0]["value"] is None + expression = yaml_dict["expression"] + assert expression["op"] == "IsNull" + assert expression["args"][0]["value"] is None roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -96,9 +104,10 @@ def test_between(compiler): a = ibis.literal(5) expr = a.between(3, 7) yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Between" - assert yaml_dict["args"][0]["value"] == 5 - assert yaml_dict["args"][1]["value"] == 3 - assert yaml_dict["args"][2]["value"] == 7 + expression = yaml_dict["expression"] + assert expression["op"] == "Between" + assert 
expression["args"][0]["value"] == 5 + assert expression["args"][1]["value"] == 3 + assert expression["args"][2]["value"] == 7 roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_operations_cast.py b/python/letsql/ibis_yaml/tests/test_operations_cast.py index 4039ec26..38380739 100644 --- a/python/letsql/ibis_yaml/tests/test_operations_cast.py +++ b/python/letsql/ibis_yaml/tests/test_operations_cast.py @@ -4,11 +4,12 @@ def test_explicit_cast(compiler): expr = ibis.literal(42).cast("float64") yaml_dict = compiler.compile_to_yaml(expr) + expression = yaml_dict["expression"] - assert yaml_dict["op"] == "Cast" - assert yaml_dict["args"][0]["op"] == "Literal" - assert yaml_dict["args"][0]["value"] == 42 - assert yaml_dict["type"]["name"] == "Float64" + assert expression["op"] == "Cast" + assert expression["args"][0]["op"] == "Literal" + assert expression["args"][0]["value"] == 42 + assert expression["type"]["name"] == "Float64" roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -19,11 +20,12 @@ def test_implicit_cast(compiler): f = ibis.literal(2.5) expr = i + f yaml_dict = compiler.compile_to_yaml(expr) + expression = yaml_dict["expression"] - assert yaml_dict["op"] == "Add" - assert yaml_dict["args"][0]["type"]["name"] == "Int8" - assert yaml_dict["args"][1]["type"]["name"] == "Float64" - assert yaml_dict["type"]["name"] == "Float64" + assert expression["op"] == "Add" + assert expression["args"][0]["type"]["name"] == "Int8" + assert expression["args"][1]["type"]["name"] == "Float64" + assert expression["type"]["name"] == "Float64" roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -32,10 +34,11 @@ def test_implicit_cast(compiler): def test_string_cast(compiler): expr = ibis.literal("42").cast("int64") yaml_dict = compiler.compile_to_yaml(expr) + expression = yaml_dict["expression"] - assert 
yaml_dict["op"] == "Cast" - assert yaml_dict["args"][0]["value"] == "42" - assert yaml_dict["type"]["name"] == "Int64" + assert expression["op"] == "Cast" + assert expression["args"][0]["value"] == "42" + assert expression["type"]["name"] == "Int64" roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -44,10 +47,11 @@ def test_string_cast(compiler): def test_timestamp_cast(compiler): expr = ibis.literal("2024-01-01").cast("timestamp") yaml_dict = compiler.compile_to_yaml(expr) + expression = yaml_dict["expression"] - assert yaml_dict["op"] == "Cast" - assert yaml_dict["args"][0]["value"] == "2024-01-01" - assert yaml_dict["type"]["name"] == "Timestamp" + assert expression["op"] == "Cast" + assert expression["args"][0]["value"] == "2024-01-01" + assert expression["type"]["name"] == "Timestamp" roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_operations_datetime.py b/python/letsql/ibis_yaml/tests/test_operations_datetime.py index 34a3c063..ded9fc91 100644 --- a/python/letsql/ibis_yaml/tests/test_operations_datetime.py +++ b/python/letsql/ibis_yaml/tests/test_operations_datetime.py @@ -12,21 +12,24 @@ def test_date_extract(compiler): year = dt_expr.year() year_yaml = compiler.compile_to_yaml(year) - assert year_yaml["op"] == "ExtractYear" - assert year_yaml["args"][0]["value"] == "2024-03-14T15:09:26" - assert year_yaml["type"]["name"] == "Int32" + expression = year_yaml["expression"] + assert expression["op"] == "ExtractYear" + assert expression["args"][0]["value"] == "2024-03-14T15:09:26" + assert expression["type"]["name"] == "Int32" roundtrip_year = compiler.compile_from_yaml(year_yaml) assert roundtrip_year.equals(year) month = dt_expr.month() month_yaml = compiler.compile_to_yaml(month) - assert month_yaml["op"] == "ExtractMonth" + expression = month_yaml["expression"] + assert expression["op"] == "ExtractMonth" roundtrip_month = 
compiler.compile_from_yaml(month_yaml) assert roundtrip_month.equals(month) day = dt_expr.day() day_yaml = compiler.compile_to_yaml(day) - assert day_yaml["op"] == "ExtractDay" + expression = day_yaml["expression"] + assert expression["op"] == "ExtractDay" roundtrip_day = compiler.compile_from_yaml(day_yaml) assert roundtrip_day.equals(day) @@ -36,21 +39,24 @@ def test_time_extract(compiler): hour = dt_expr.hour() hour_yaml = compiler.compile_to_yaml(hour) - assert hour_yaml["op"] == "ExtractHour" - assert hour_yaml["args"][0]["value"] == "2024-03-14T15:09:26" - assert hour_yaml["type"]["name"] == "Int32" + hour_expression = hour_yaml["expression"] + assert hour_expression["op"] == "ExtractHour" + assert hour_expression["args"][0]["value"] == "2024-03-14T15:09:26" + assert hour_expression["type"]["name"] == "Int32" roundtrip_hour = compiler.compile_from_yaml(hour_yaml) assert roundtrip_hour.equals(hour) minute = dt_expr.minute() minute_yaml = compiler.compile_to_yaml(minute) - assert minute_yaml["op"] == "ExtractMinute" + minute_expression = minute_yaml["expression"] + assert minute_expression["op"] == "ExtractMinute" roundtrip_minute = compiler.compile_from_yaml(minute_yaml) assert roundtrip_minute.equals(minute) second = dt_expr.second() second_yaml = compiler.compile_to_yaml(second) - assert second_yaml["op"] == "ExtractSecond" + second_expression = second_yaml["expression"] + assert second_expression["op"] == "ExtractSecond" roundtrip_second = compiler.compile_from_yaml(second_yaml) assert roundtrip_second.equals(second) @@ -61,17 +67,19 @@ def test_timestamp_arithmetic(compiler): plus_day = ts + delta yaml_dict = compiler.compile_to_yaml(plus_day) - assert yaml_dict["op"] == "TimestampAdd" - assert yaml_dict["type"]["name"] == "Timestamp" - assert yaml_dict["args"][1]["type"]["name"] == "Interval" + expression = yaml_dict["expression"] + assert expression["op"] == "TimestampAdd" + assert expression["type"]["name"] == "Timestamp" + assert 
expression["args"][1]["type"]["name"] == "Interval" roundtrip_plus = compiler.compile_from_yaml(yaml_dict) assert roundtrip_plus.equals(plus_day) minus_day = ts - delta yaml_dict = compiler.compile_to_yaml(minus_day) - assert yaml_dict["op"] == "TimestampSub" - assert yaml_dict["type"]["name"] == "Timestamp" - assert yaml_dict["args"][1]["type"]["name"] == "Interval" + expression = yaml_dict["expression"] + assert expression["op"] == "TimestampSub" + assert expression["type"]["name"] == "Timestamp" + assert expression["args"][1]["type"]["name"] == "Interval" roundtrip_minus = compiler.compile_from_yaml(yaml_dict) assert roundtrip_minus.equals(minus_day) @@ -81,8 +89,9 @@ def test_timestamp_diff(compiler): ts2 = ibis.literal(datetime(2024, 3, 15)) diff = ts2 - ts1 yaml_dict = compiler.compile_to_yaml(diff) - assert yaml_dict["op"] == "TimestampDiff" - assert yaml_dict["type"]["name"] == "Interval" + expression = yaml_dict["expression"] + assert expression["op"] == "TimestampDiff" + assert expression["type"]["name"] == "Interval" roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(diff) @@ -90,16 +99,18 @@ def test_timestamp_diff(compiler): def test_temporal_unit_yaml(compiler): interval_date = ibis.literal(5, type=dt.Interval(unit=tm.DateUnit("D"))) yaml_date = compiler.compile_to_yaml(interval_date) - assert yaml_date["type"]["name"] == "Interval" - assert yaml_date["type"]["unit"]["name"] == "DateUnit" - assert yaml_date["type"]["unit"]["value"] == "D" + expression_date = yaml_date["expression"] + assert expression_date["type"]["name"] == "Interval" + assert expression_date["type"]["unit"]["name"] == "DateUnit" + assert expression_date["type"]["unit"]["value"] == "D" roundtrip_date = compiler.compile_from_yaml(yaml_date) assert roundtrip_date.equals(interval_date) interval_time = ibis.literal(10, type=dt.Interval(unit=tm.TimeUnit("h"))) yaml_time = compiler.compile_to_yaml(interval_time) - assert yaml_time["type"]["name"] == 
"Interval" - assert yaml_time["type"]["unit"]["name"] == "TimeUnit" - assert yaml_time["type"]["unit"]["value"] == "h" + expression_time = yaml_time["expression"] + assert expression_time["type"]["name"] == "Interval" + assert expression_time["type"]["unit"]["name"] == "TimeUnit" + assert expression_time["type"]["unit"]["value"] == "h" roundtrip_time = compiler.compile_from_yaml(yaml_time) assert roundtrip_time.equals(interval_time) diff --git a/python/letsql/ibis_yaml/tests/test_relations.py b/python/letsql/ibis_yaml/tests/test_relations.py index c14a5799..19fe9b57 100644 --- a/python/letsql/ibis_yaml/tests/test_relations.py +++ b/python/letsql/ibis_yaml/tests/test_relations.py @@ -4,11 +4,12 @@ def test_filter(compiler, t): expr = t.filter(t.a > 0) yaml_dict = compiler.compile_to_yaml(expr) + expression = yaml_dict["expression"] # Original assertions - assert yaml_dict["op"] == "Filter" - assert yaml_dict["predicates"][0]["op"] == "Greater" - assert yaml_dict["parent"]["op"] == "UnboundTable" + assert expression["op"] == "Filter" + assert expression["predicates"][0]["op"] == "Greater" + assert expression["parent"]["op"] == "UnboundTable" # Roundtrip test: compile from YAML and verify equality roundtrip_expr = compiler.compile_from_yaml(yaml_dict) @@ -18,11 +19,12 @@ def test_filter(compiler, t): def test_projection(compiler, t): expr = t.select(["a", "b"]) yaml_dict = compiler.compile_to_yaml(expr) + expression = yaml_dict["expression"] # Original assertions - assert yaml_dict["op"] == "Project" - assert yaml_dict["parent"]["op"] == "UnboundTable" - assert set(yaml_dict["values"]) == {"a", "b"} + assert expression["op"] == "Project" + assert expression["parent"]["op"] == "UnboundTable" + assert set(expression["values"]) == {"a", "b"} # Roundtrip test roundtrip_expr = compiler.compile_from_yaml(yaml_dict) @@ -32,11 +34,11 @@ def test_projection(compiler, t): def test_aggregation(compiler, t): expr = t.group_by("a").aggregate(avg_c=t.c.mean()) yaml_dict = 
compiler.compile_to_yaml(expr) + expression = yaml_dict["expression"] - # Original assertions - assert yaml_dict["op"] == "Aggregate" - assert yaml_dict["by"][0]["name"] == "a" - assert yaml_dict["metrics"]["avg_c"]["op"] == "Mean" + assert expression["op"] == "Aggregate" + assert expression["by"][0]["name"] == "a" + assert expression["metrics"]["avg_c"]["op"] == "Mean" # Roundtrip test roundtrip_expr = compiler.compile_from_yaml(yaml_dict) @@ -48,12 +50,11 @@ def test_join(compiler): t2 = ibis.table(dict(b="string", c="float"), name="t2") expr = t1.join(t2, t1.b == t2.b) yaml_dict = compiler.compile_to_yaml(expr) + expression = yaml_dict["expression"] - # Original assertions - assert yaml_dict["op"] == "JoinChain" - # The first join link's predicates - assert yaml_dict["rest"][0]["predicates"][0]["op"] == "Equals" - assert yaml_dict["rest"][0]["how"] == "inner" + assert expression["op"] == "JoinChain" + assert expression["rest"][0]["predicates"][0]["op"] == "Equals" + assert expression["rest"][0]["how"] == "inner" # Roundtrip test roundtrip_expr = compiler.compile_from_yaml(yaml_dict) @@ -63,12 +64,11 @@ def test_join(compiler): def test_order_by(compiler, t): expr = t.order_by(["a", "b"]) yaml_dict = compiler.compile_to_yaml(expr) + expression = yaml_dict["expression"] - # Original assertions - assert yaml_dict["op"] == "Sort" - assert len(yaml_dict["keys"]) == 2 + assert expression["op"] == "Sort" + assert len(expression["keys"]) == 2 - # Roundtrip test roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -76,11 +76,9 @@ def test_order_by(compiler, t): def test_limit(compiler, t): expr = t.limit(10) yaml_dict = compiler.compile_to_yaml(expr) + expression = yaml_dict["expression"] - # Original assertions - assert yaml_dict["op"] == "Limit" - assert yaml_dict["n"] == 10 - - # Roundtrip test + assert expression["op"] == "Limit" + assert expression["n"] == 10 roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert 
roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_string_ops.py b/python/letsql/ibis_yaml/tests/test_string_ops.py index 74adbfac..bbfdc6f3 100644 --- a/python/letsql/ibis_yaml/tests/test_string_ops.py +++ b/python/letsql/ibis_yaml/tests/test_string_ops.py @@ -5,7 +5,7 @@ def test_string_concat(compiler): s1 = ibis.literal("hello") s2 = ibis.literal("world") expr = s1 + s2 - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.compile_to_yaml(expr)["expression"] assert yaml_dict["op"] == "StringConcat" assert yaml_dict["args"][0]["value"] == "hello" @@ -18,11 +18,11 @@ def test_string_upper_lower(compiler): upper_expr = s.upper() lower_expr = s.lower() - upper_yaml = compiler.compile_to_yaml(upper_expr) + upper_yaml = compiler.compile_to_yaml(upper_expr)["expression"] assert upper_yaml["op"] == "Uppercase" assert upper_yaml["args"][0]["value"] == "Hello" - lower_yaml = compiler.compile_to_yaml(lower_expr) + lower_yaml = compiler.compile_to_yaml(lower_expr)["expression"] assert lower_yaml["op"] == "Lowercase" assert lower_yaml["args"][0]["value"] == "Hello" @@ -30,7 +30,7 @@ def test_string_upper_lower(compiler): def test_string_length(compiler): s = ibis.literal("hello") expr = s.length() - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.compile_to_yaml(expr)["expression"] assert yaml_dict["op"] == "StringLength" assert yaml_dict["args"][0]["value"] == "hello" @@ -40,7 +40,7 @@ def test_string_length(compiler): def test_string_substring(compiler): s = ibis.literal("hello world") expr = s.substr(0, 5) - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.compile_to_yaml(expr)["expression"] assert yaml_dict["op"] == "Substring" assert yaml_dict["args"][0]["value"] == "hello world" diff --git a/python/letsql/ibis_yaml/tests/test_subquery.py b/python/letsql/ibis_yaml/tests/test_subquery.py index 961e68a8..d9056aa2 100644 --- a/python/letsql/ibis_yaml/tests/test_subquery.py +++ 
b/python/letsql/ibis_yaml/tests/test_subquery.py @@ -5,9 +5,10 @@ def test_scalar_subquery(compiler, t): expr = ops.ScalarSubquery(t.c.mean().as_table()).to_expr() yaml_dict = compiler.compile_to_yaml(expr) + expression = yaml_dict["expression"] - assert yaml_dict["op"] == "ScalarSubquery" - assert yaml_dict["args"][0]["op"] == "Aggregate" + assert expression["op"] == "ScalarSubquery" + assert expression["args"][0]["op"] == "Aggregate" roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -20,9 +21,10 @@ def test_exists_subquery(compiler): filtered = t2.filter(t2.a == t1.a) expr = ops.ExistsSubquery(filtered).to_expr() yaml_dict = compiler.compile_to_yaml(expr) + expression = yaml_dict["expression"] - assert yaml_dict["op"] == "ExistsSubquery" - assert yaml_dict["rel"]["op"] == "Filter" + assert expression["op"] == "ExistsSubquery" + assert expression["rel"]["op"] == "Filter" roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -34,9 +36,10 @@ def test_in_subquery(compiler): expr = ops.InSubquery(t1.select("a"), t2.a).to_expr() yaml_dict = compiler.compile_to_yaml(expr) + expression = yaml_dict["expression"] - assert yaml_dict["op"] == "InSubquery" - assert yaml_dict["type"]["name"] == "Boolean" + assert expression["op"] == "InSubquery" + assert expression["type"]["name"] == "Boolean" roundtrip_expr = compiler.compile_from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_window_functions.py b/python/letsql/ibis_yaml/tests/test_window_functions.py index 805aa52d..36b03253 100644 --- a/python/letsql/ibis_yaml/tests/test_window_functions.py +++ b/python/letsql/ibis_yaml/tests/test_window_functions.py @@ -37,8 +37,9 @@ def test_aggregation_window(compiler, t): ) yaml_dict = compiler.compile_to_yaml(expr) - assert yaml_dict["op"] == "Project" - window_func = yaml_dict["values"]["mean_c"] + expression = yaml_dict["expression"] + assert 
expression["op"] == "Project" + window_func = expression["values"]["mean_c"] assert window_func["op"] == "WindowFunction" assert window_func["args"][0]["op"] == "Mean" @@ -52,6 +53,4 @@ def test_aggregation_window(compiler, t): else: assert window_func["end"] == following - print(yaml_dict) - assert window_func["group_by"][0]["name"] == "a" diff --git a/python/letsql/ibis_yaml/translate.py b/python/letsql/ibis_yaml/translate.py index 2ada25c0..74a97029 100644 --- a/python/letsql/ibis_yaml/translate.py +++ b/python/letsql/ibis_yaml/translate.py @@ -3,7 +3,6 @@ import datetime import decimal import functools -import pathlib from typing import Any import ibis @@ -14,7 +13,6 @@ import ibis.expr.types as ir import pyarrow.parquet as pq from ibis.common.annotations import Argument -from ibis.common.exceptions import IbisTypeError import letsql as ls from letsql.expr.relations import CachedNode, RemoteTable, into_backend @@ -30,6 +28,26 @@ FROM_YAML_HANDLERS: dict[str, Any] = {} +class SchemaRegistry: + def __init__(self): + self.schemas = {} + self.counter = 0 + + def register_schema(self, schema): + frozen_schema = freeze( + {name: _translate_type(dtype) for name, dtype in schema.items()} + ) + + for schema_id, existing_schema in self.schemas.items(): + if existing_schema == frozen_schema: + return schema_id + + schema_id = f"schema_{self.counter}" + self.schemas[schema_id] = frozen_schema + self.counter += 1 + return schema_id + + def register_from_yaml_handler(*op_names: str): def decorator(func): for name in op_names: @@ -240,13 +258,12 @@ def _base_op_to_yaml(op: ops.Node, compiler: Any) -> dict: @translate_to_yaml.register(ops.UnboundTable) def _unbound_table_to_yaml(op: ops.UnboundTable, compiler: Any) -> dict: + schema_id = compiler.schema_registry.register_schema(op.schema) return freeze( { "op": "UnboundTable", "name": op.name, - "schema": { - name: _translate_type(dtype) for name, dtype in op.schema.items() - }, + "schema_ref": schema_id, } ) @@ -254,52 
+271,65 @@ def _unbound_table_to_yaml(op: ops.UnboundTable, compiler: Any) -> dict: @register_from_yaml_handler("UnboundTable") def _unbound_table_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: table_name = yaml_dict["name"] - schema = [ - (name, _type_from_yaml(dtype)) for name, dtype in yaml_dict["schema"].items() - ] + if not hasattr(compiler, "definitions"): + raise ValueError("Compiler missing definitions with schemas") + + schema_ref = yaml_dict["schema_ref"] + try: + schema_def = compiler.definitions["schemas"][schema_ref] + except KeyError: + raise ValueError(f"Schema {schema_ref} not found in definitions") + + schema = { + name: _type_from_yaml(dtype_yaml) for name, dtype_yaml in schema_def.items() + } return ibis.table(schema, name=table_name) @translate_to_yaml.register(ops.DatabaseTable) def _database_table_to_yaml(op: ops.DatabaseTable, compiler: Any) -> dict: profile_name = getattr(op.source, "profile_name", None) + schema_id = compiler.schema_registry.register_schema(op.schema) + return freeze( { "op": "DatabaseTable", "table": op.name, - "schema": { - name: _translate_type(dtype) for name, dtype in op.schema.items() - }, + "schema_ref": schema_id, "profile": profile_name, } ) @register_from_yaml_handler("DatabaseTable") -def _database_table_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: +def database_table_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: profile_name = yaml_dict.get("profile") table_name = yaml_dict.get("table") - if not profile_name or not table_name: - raise ValueError( - "Missing 'profile' or 'table' information in YAML for DatabaseTable." 
- ) + # we should validate that schema is the same + schema_ref = yaml_dict.get("schema_ref") + + if not all([profile_name, table_name, schema_ref]): + raise ValueError("Missing required information in YAML for DatabaseTable.") try: con = compiler.profiles[profile_name] except KeyError: raise ValueError(f"Profile {profile_name!r} not found in compiler.profiles") + if not hasattr(compiler, "definitions"): + raise ValueError("Compiler missing definitions with schemas") + return con.table(table_name) @translate_to_yaml.register(CachedNode) def _cached_node_to_yaml(op: CachedNode, compiler: any) -> dict: + schema_id = compiler.schema_registry.register_schema(op.schema) + return freeze( { "op": "CachedNode", - "schema": { - name: _translate_type(dtype) for name, dtype in op.schema.items() - }, + "schema_ref": schema_id, "parent": translate_to_yaml(op.parent, compiler), "source": getattr(op.source, "profile_name", None), "storage": translate_storage(op.storage, compiler), @@ -310,10 +340,19 @@ def _cached_node_to_yaml(op: CachedNode, compiler: any) -> dict: @register_from_yaml_handler("CachedNode") def _cached_node_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: + if not hasattr(compiler, "definitions"): + raise ValueError("Compiler missing definitions with schemas") + + schema_ref = yaml_dict["schema_ref"] + try: + schema_def = compiler.definitions["schemas"][schema_ref] + except KeyError: + raise ValueError(f"Schema {schema_ref} not found in definitions") + schema = { - name: _type_from_yaml(dtype_yaml) - for name, dtype_yaml in yaml_dict["schema"].items() + name: _type_from_yaml(dtype_yaml) for name, dtype_yaml in schema_def.items() } + parent_expr = translate_from_yaml(yaml_dict["parent"], compiler) profile_name = yaml_dict.get("source") try: @@ -339,22 +378,42 @@ def _memtable_to_yaml(op: ops.InMemoryTable, compiler: Any) -> dict: ) arrow_table = op.data.to_pyarrow(op.schema) - file_path = compiler.current_path / f"memtable_{id(op)}.parquet" 
pq.write_table(arrow_table, str(file_path)) + # probably do not need to store schema + schema_id = compiler.schema_registry.register_schema(op.schema) return freeze( { "op": "InMemoryTable", "table": op.name, - "schema": { - name: _translate_type(dtype) for name, dtype in op.schema.items() - }, + "schema_ref": schema_id, "file": str(file_path), } ) +@register_from_yaml_handler("InMemoryTable") +def _memtable_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: + if not hasattr(compiler, "definitions"): + raise ValueError("Compiler missing definitions with schemas") + + file_path = yaml_dict["file"] + schema_ref = yaml_dict["schema_ref"] + try: + schema_def = compiler.definitions["schemas"][schema_ref] + except KeyError: + raise ValueError(f"Schema {schema_ref} not found in definitions") + + arrow_table = pq.read_table(file_path) + df = arrow_table.to_pandas() + table_name = yaml_dict.get("table", "memtable") + + column_names = list(schema_def.keys()) + memtable_expr = ls.memtable(df, columns=column_names, name=table_name) + return memtable_expr + + @register_from_yaml_handler("InMemoryTable") def _memtable_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: file_path = yaml_dict["file"] @@ -371,13 +430,13 @@ def _memtable_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: def _remotetable_to_yaml(op: RemoteTable, compiler: any) -> dict: profile_name = getattr(op.source, "profile_name", None) remote_expr_yaml = translate_to_yaml(op.remote_expr, compiler) + schema_id = compiler.schema_registry.register_schema(op.schema) + return freeze( { "op": "RemoteTable", "table": op.name, - "schema": { - name: _translate_type(dtype) for name, dtype in op.schema.items() - }, + "schema_ref": schema_id, "profile": profile_name, "remote_expr": remote_expr_yaml, } @@ -391,7 +450,7 @@ def _remotetable_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: remote_expr_yaml = yaml_dict.get("remote_expr") if profile_name is None: raise ValueError( - "Missing keys in 
RemoteTable YAML; ensure 'profile', 'table', and 'remote_expr' are present." + "Missing keys in RemoteTable YAML; ensure 'profile_name' are present." ) try: con = compiler.profiles[profile_name] @@ -593,21 +652,25 @@ def _aggregate_to_yaml(op: ops.Aggregate, compiler: Any) -> dict: @register_from_yaml_handler("Aggregate") def _aggregate_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + if not hasattr(compiler, "definitions"): + raise ValueError("Compiler missing definitions with schemas") + parent = translate_from_yaml(yaml_dict["parent"], compiler) groups = tuple( translate_from_yaml(group, compiler) for group in yaml_dict.get("by", []) ) - raw_metrics = { + metrics = { name: translate_from_yaml(metric, compiler) for name, metric in yaml_dict.get("metrics", {}).items() } - metrics = raw_metrics - if groups: - return parent.group_by(list(groups)).aggregate(metrics) - else: - return parent.aggregate(metrics) + result = ( + parent.group_by(list(groups)).aggregate(metrics) + if groups + else parent.aggregate(metrics) + ) + return result @translate_to_yaml.register(ops.JoinChain) @@ -773,33 +836,24 @@ def _field_to_yaml(op: ops.Field, compiler: Any) -> dict: "relation": translate_to_yaml(op.rel, compiler), "type": _translate_type(op.dtype), } + if op.args and len(op.args) >= 2 and isinstance(op.args[1], str): underlying_name = op.args[1] if underlying_name != op.name: result["original_name"] = underlying_name + return freeze(result) @register_from_yaml_handler("Field") def field_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - relation = translate_from_yaml(yaml_dict["relation"], compiler) + if not hasattr(compiler, "definitions"): + raise ValueError("Compiler missing definitions with schemas") + relation = translate_from_yaml(yaml_dict["relation"], compiler) target_name = yaml_dict["name"] source_name = yaml_dict.get("original_name", target_name) - - schema = relation.schema() if callable(relation.schema) else relation.schema - - if source_name not in 
schema.names: - if target_name in schema.names: - source_name = target_name - else: - columns_formatted = ", ".join(schema.names) - raise IbisTypeError( - f"Column {source_name!r} not found in table. " - f"Existing columns: {columns_formatted}." - ) field = relation[source_name] - if target_name != source_name: field = field.name(target_name) @@ -1301,40 +1355,3 @@ def _type_from_yaml(yaml_dict: dict) -> dt.DataType: nullable=yaml_dict.get("nullable", True), ), } - -# === Helper functions for translating cache storage === - - -def translate_storage(storage, compiler: any) -> dict: - from letsql.common.caching import ParquetStorage, SourceStorage - - if isinstance(storage, ParquetStorage): - return {"type": "ParquetStorage", "path": str(storage.path)} - elif isinstance(storage, SourceStorage): - return { - "type": "SourceStorage", - "source": getattr(storage.source, "profile_name", None), - } - else: - raise NotImplementedError(f"Unknown storage type: {type(storage)}") - - -def load_storage_from_yaml(storage_yaml: dict, compiler: any): - from letsql.expr.relations import ParquetStorage, _SourceStorage - - if storage_yaml["type"] == "ParquetStorage": - default_profile = list(compiler.profiles.values())[0] - return ParquetStorage( - source=default_profile, path=pathlib.Path(storage_yaml["path"]) - ) - elif storage_yaml["type"] == "SourceStorage": - source_profile_name = storage_yaml["source"] - try: - source = compiler.profiles[source_profile_name] - except KeyError: - raise ValueError( - f"Source profile {source_profile_name!r} not found in compiler.profiles" - ) - return _SourceStorage(source=source) - else: - raise NotImplementedError(f"Unknown storage type: {storage_yaml['type']}") diff --git a/python/letsql/ibis_yaml/utils.py b/python/letsql/ibis_yaml/utils.py index e5e46eea..a791d0f2 100644 --- a/python/letsql/ibis_yaml/utils.py +++ b/python/letsql/ibis_yaml/utils.py @@ -1,10 +1,11 @@ import base64 from collections.abc import Mapping, Sequence +from typing 
import Any, Dict import cloudpickle from ibis.common.collections import FrozenOrderedDict -from letsql.common.caching import ParquetStorage, SourceStorage +from letsql.common.caching import SourceStorage def serialize_udf_function(fn: callable) -> str: @@ -129,10 +130,8 @@ def diff_ibis_exprs(expr1, expr2): return diffs -def translate_storage(storage, compiler: any) -> dict: - if isinstance(storage, ParquetStorage): - return {"type": "ParquetStorage", "path": str(storage.path)} - elif isinstance(storage, SourceStorage): +def translate_storage(storage, compiler: Any) -> Dict: + if isinstance(storage, SourceStorage): return { "type": "SourceStorage", "source": getattr(storage.source, "profile_name", None), @@ -141,7 +140,7 @@ def translate_storage(storage, compiler: any) -> dict: raise NotImplementedError(f"Unknown storage type: {type(storage)}") -def load_storage_from_yaml(storage_yaml: dict, compiler: any): +def load_storage_from_yaml(storage_yaml: Dict, compiler: Any): if storage_yaml["type"] == "SourceStorage": source_profile_name = storage_yaml["source"] try: From ee4c12e7894f3a60cfe3ea2a7946eeaec3192cdc Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sun, 16 Feb 2025 18:11:06 -0500 Subject: [PATCH 09/45] feat: add RowNumber support --- .../ibis_yaml/tests/test_window_functions.py | 44 +++++++++++++++++++ python/letsql/ibis_yaml/translate.py | 14 ++++++ 2 files changed, 58 insertions(+) diff --git a/python/letsql/ibis_yaml/tests/test_window_functions.py b/python/letsql/ibis_yaml/tests/test_window_functions.py index 36b03253..53c6b53c 100644 --- a/python/letsql/ibis_yaml/tests/test_window_functions.py +++ b/python/letsql/ibis_yaml/tests/test_window_functions.py @@ -54,3 +54,47 @@ def test_aggregation_window(compiler, t): assert window_func["end"] == following assert window_func["group_by"][0]["name"] == "a" + + +def test_row_number_simple_roundtrip(compiler, t): + expr = t.select([ibis.row_number().name("row_num")]) + yaml_dict = 
compiler.compile_to_yaml(expr) + reconstructed_expr = compiler.compile_from_yaml(yaml_dict) + assert expr.equals(reconstructed_expr) + + +def test_row_number_window_roundtrip(compiler, t): + expr = t.select( + [ + ibis.row_number() + .over( + ibis.window( + group_by=[t.a, t.b], + order_by=[t.c.desc(), t.d], + preceding=5, + following=0, + ) + ) + .name("row_num") + ] + ) + yaml_dict = compiler.compile_to_yaml(expr) + reconstructed_expr = compiler.compile_from_yaml(yaml_dict) + assert expr.equals(reconstructed_expr) + + +def test_multiple_rank_expressions_roundtrip(compiler, t): + expr = t.select( + [ + ibis.row_number().over(ibis.window(group_by=t.a)).name("simple_row_num"), + ibis.row_number() + .over(ibis.window(group_by=[t.a, t.b], order_by=t.c.desc())) + .name("ordered_row_num"), + t.c.mean() + .over(ibis.window(preceding=3, following=0, group_by=t.a)) + .name("mean_c"), + ] + ) + yaml_dict = compiler.compile_to_yaml(expr) + reconstructed_expr = compiler.compile_from_yaml(yaml_dict) + assert expr.equals(reconstructed_expr) diff --git a/python/letsql/ibis_yaml/translate.py b/python/letsql/ibis_yaml/translate.py index 74a97029..0af15a70 100644 --- a/python/letsql/ibis_yaml/translate.py +++ b/python/letsql/ibis_yaml/translate.py @@ -938,6 +938,20 @@ def _count_distinct_to_yaml(op: ops.CountDistinct, compiler: Any) -> dict: ) +@translate_to_yaml.register(ops.RankBase) +def _rank_base_to_yaml(op: ops.RankBase, compiler: Any) -> dict: + return freeze( + { + "op": type(op).__name__, + } + ) + + +@register_from_yaml_handler("RowNumber") +def _row_number_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + return ibis.row_number() + + @register_from_yaml_handler("CountDistinct") def _count_distinct_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: arg = translate_from_yaml(yaml_dict["args"][0], compiler) From f29a93e4eb3925cdc499ae8d68078651fe2bd931 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sun, 16 Feb 2025 18:34:19 -0500 Subject: [PATCH 10/45] feat: 
add schema for expr output in yaml --- python/letsql/ibis_yaml/compiler.py | 10 ++++++++++ python/letsql/ibis_yaml/tests/test_compiler.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/python/letsql/ibis_yaml/compiler.py b/python/letsql/ibis_yaml/compiler.py index 47810656..9f41c8cd 100644 --- a/python/letsql/ibis_yaml/compiler.py +++ b/python/letsql/ibis_yaml/compiler.py @@ -148,6 +148,8 @@ def compile_to_yaml(self, expr) -> Dict: expr_hash = self.build_manager.get_expr_hash(expr) self.current_path = self.build_manager.get_build_path(expr_hash) expr_yaml = translate_to_yaml(expr, self) + schema_ref = self.get_expr_schema_ref(expr) + expr_yaml = freeze({**dict(expr_yaml), "schema_ref": schema_ref}) schema_definitions = {} for schema_id, schema in self.schema_registry.schemas.items(): @@ -156,3 +158,11 @@ def compile_to_yaml(self, expr) -> Dict: return freeze( {"definitions": {"schemas": schema_definitions}, "expression": expr_yaml} ) + + def get_expr_schema_ref(self, expr: ir.Expr) -> str: + if hasattr(expr, "schema"): + schema = expr.schema() + schema_ref = self.schema_registry.register_schema(schema) + else: + schema_ref = None + return schema_ref diff --git a/python/letsql/ibis_yaml/tests/test_compiler.py b/python/letsql/ibis_yaml/tests/test_compiler.py index cd077563..2f22e550 100644 --- a/python/letsql/ibis_yaml/tests/test_compiler.py +++ b/python/letsql/ibis_yaml/tests/test_compiler.py @@ -2,6 +2,7 @@ import pathlib import dask +import yaml from ibis.common.collections import FrozenOrderedDict import letsql as ls @@ -123,3 +124,19 @@ def test_compiler_sql(build_dir): ) assert sql_text == expected_result + + +def test_ibis_compiler_expr_schema_ref(t, build_dir): + t = ls.memtable({"a": [0, 1], "b": [0, 1]}) + backend = t._find_backend() + backend.profile_name = "default" + expr = t.filter(t.a == 1).drop("b") + compiler = IbisYamlCompiler(build_dir) + compiler.profiles = {"default": backend} + compiler.compile(expr) + expr_hash = 
dask.base.tokenize(expr)[:12] + + with open(build_dir / expr_hash / "expr.yaml") as f: + yaml_dict = yaml.safe_load(f) + + assert yaml_dict["expression"]["schema_ref"] From dabd66e5d55032983cb5ae7ae15cd0da126ac2e4 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Mon, 17 Feb 2025 09:27:59 -0500 Subject: [PATCH 11/45] refactor: refactor user facing class -> BuildManager - drop IbisYamlCompiler -> YamlExpressionTranslator - drop BuildManager -> ArtifactStore - Add BuildManager as a user-facing class --- python/letsql/ibis_yaml/compiler.py | 144 +++++++++--------- python/letsql/ibis_yaml/sql.py | 1 + python/letsql/ibis_yaml/tests/conftest.py | 4 +- .../letsql/ibis_yaml/tests/test_arithmetic.py | 24 +-- python/letsql/ibis_yaml/tests/test_basic.py | 52 +++---- .../letsql/ibis_yaml/tests/test_compiler.py | 40 +++-- .../letsql/ibis_yaml/tests/test_join_chain.py | 4 +- .../letsql/ibis_yaml/tests/test_letsql_ops.py | 31 ++-- .../tests/test_operations_boolean.py | 36 ++--- .../ibis_yaml/tests/test_operations_cast.py | 16 +- .../tests/test_operations_datetime.py | 44 +++--- .../letsql/ibis_yaml/tests/test_relations.py | 24 +-- .../letsql/ibis_yaml/tests/test_selection.py | 4 +- .../letsql/ibis_yaml/tests/test_string_ops.py | 10 +- .../letsql/ibis_yaml/tests/test_subquery.py | 12 +- python/letsql/ibis_yaml/tests/test_tpch.py | 4 +- python/letsql/ibis_yaml/tests/test_udf.py | 8 +- .../ibis_yaml/tests/test_window_functions.py | 18 +-- 18 files changed, 233 insertions(+), 243 deletions(-) diff --git a/python/letsql/ibis_yaml/compiler.py b/python/letsql/ibis_yaml/compiler.py index 9f41c8cd..06f866ac 100644 --- a/python/letsql/ibis_yaml/compiler.py +++ b/python/letsql/ibis_yaml/compiler.py @@ -27,7 +27,7 @@ def represent_frozenordereddict(self, data): ) -class StorageHandler: +class ArtifactStore: def __init__(self, root_path: pathlib.Path): self.root_path = ( Path(root_path) if not isinstance(root_path, Path) else root_path @@ -45,7 +45,6 @@ def ensure_dir(self, *parts) -> 
pathlib.Path: def write_yaml(self, data: Dict[str, Any], *path_parts) -> pathlib.Path: path = self.get_path(*path_parts) path.parent.mkdir(parents=True, exist_ok=True) - with path.open("w") as f: yaml.dump( data, @@ -60,14 +59,12 @@ def read_yaml(self, *path_parts) -> Dict[str, Any]: path = self.get_path(*path_parts) if not path.exists(): raise FileNotFoundError(f"File not found: {path}") - with path.open("r") as f: return yaml.safe_load(f) def write_text(self, content: str, *path_parts) -> pathlib.Path: path = self.get_path(*path_parts) path.parent.mkdir(parents=True, exist_ok=True) - with path.open("w") as f: f.write(content) return path @@ -76,93 +73,90 @@ def read_text(self, *path_parts) -> str: path = self.get_path(*path_parts) if not path.exists(): raise FileNotFoundError(f"File not found: {path}") - with path.open("r") as f: return f.read() def exists(self, *path_parts) -> bool: return self.get_path(*path_parts).exists() - -class BuildManager: - def __init__(self, storage_path: pathlib.Path): - self.storage = StorageHandler(storage_path) - def get_expr_hash(self, expr) -> str: expr_hash = dask.base.tokenize(expr) return expr_hash[:12] # TODO: make length of hash as a config - def save_yaml(self, yaml_dict: Dict[str, Any], expr_hash) -> pathlib.Path: - filename = ( - "sql.yaml" - if isinstance(yaml_dict, dict) and "queries" in yaml_dict - else "expr.yaml" - ) - return self.storage.write_yaml(yaml_dict, expr_hash, filename) + def save_yaml(self, yaml_dict: Dict[str, Any], expr_hash, filename) -> pathlib.Path: + return self.write_yaml(yaml_dict, expr_hash, filename) - def load_yaml(self, expr_hash: str) -> Dict[str, Any]: - return self.storage.read_yaml(expr_hash, "expr.yaml") + def load_yaml(self, expr_hash: str, filename) -> Dict[str, Any]: + return self.read_yaml(expr_hash, filename) def get_build_path(self, expr_hash: str) -> pathlib.Path: - return self.storage.ensure_dir(expr_hash) - - -class IbisYamlCompiler: - def __init__(self, build_dir, 
build_manager=BuildManager): - self.build_manager = build_manager(build_dir) - self.schema_registry = SchemaRegistry() - self.current_path = None - - def compile(self, expr): - yaml_dict = self.compile_to_yaml(expr) - expr_hash = self.build_manager.get_expr_hash(expr) - self.curent_path = self.build_manager.get_build_path(expr_hash) - self.build_manager.save_yaml(yaml_dict, expr_hash) - plans = generate_sql_plans(expr) - self.build_manager.save_yaml(plans, expr_hash) - - def from_hash(self, expr_hash) -> ir.Expr: - yaml_dict = self.build_manager.load_yaml(expr_hash) - - def convert_to_frozen(d): - if isinstance(d, dict): - items = [] - for k, v in d.items(): - converted_v = convert_to_frozen(v) - if isinstance(converted_v, list): - converted_v = tuple(converted_v) - items.append((k, converted_v)) - return FrozenOrderedDict(items) - elif isinstance(d, list): - return [convert_to_frozen(x) for x in d] - return d - - yaml_dict = convert_to_frozen(yaml_dict) - return self.compile_from_yaml(yaml_dict) - - def compile_from_yaml(self, yaml_dict) -> ir.Expr: - self.definitions = yaml_dict.get("definitions", {}) - return translate_from_yaml(yaml_dict["expression"], self) - - def compile_to_yaml(self, expr) -> Dict: - expr_hash = self.build_manager.get_expr_hash(expr) - self.current_path = self.build_manager.get_build_path(expr_hash) - expr_yaml = translate_to_yaml(expr, self) - schema_ref = self.get_expr_schema_ref(expr) - expr_yaml = freeze({**dict(expr_yaml), "schema_ref": schema_ref}) - - schema_definitions = {} - for schema_id, schema in self.schema_registry.schemas.items(): - schema_definitions[schema_id] = schema + return self.ensure_dir(expr_hash) + + +class YamlExpressionTranslator: + def __init__( + self, + schema_registry: SchemaRegistry = None, + profiles: Dict = None, + current_path: Path = None, + ): + self.schema_registry = schema_registry or SchemaRegistry() + self.definitions = {} + self.profiles = profiles or {} + self.current_path = current_path + + def 
to_yaml(self, expr: ir.Expr) -> Dict[str, Any]: + schema_ref = self._register_expr_schema(expr) + expr_dict = translate_to_yaml(expr, self) + expr_dict = freeze({**dict(expr_dict), "schema_ref": schema_ref}) return freeze( - {"definitions": {"schemas": schema_definitions}, "expression": expr_yaml} + { + "definitions": {"schemas": self.schema_registry.schemas}, + "expression": expr_dict, + } ) - def get_expr_schema_ref(self, expr: ir.Expr) -> str: + def from_yaml(self, yaml_dict: Dict[str, Any]) -> ir.Expr: + self.definitions = yaml_dict.get("definitions", {}) + expr_dict = freeze(yaml_dict["expression"]) + return translate_from_yaml(expr_dict, self) + + def _register_expr_schema(self, expr: ir.Expr) -> str: if hasattr(expr, "schema"): schema = expr.schema() - schema_ref = self.schema_registry.register_schema(schema) - else: - schema_ref = None - return schema_ref + return self.schema_registry.register_schema(schema) + return None + + +class BuildManager: + def __init__(self, build_dir: pathlib.Path): + self.artifact_store = ArtifactStore(build_dir) + self.profiles = {} + + def compile_expr(self, expr: ir.Expr) -> None: + expr_hash = self.artifact_store.get_expr_hash(expr) + current_path = self.artifact_store.get_build_path(expr_hash) + + translator = YamlExpressionTranslator( + profiles=self.profiles, current_path=current_path + ) + # metadata.yaml (uv.lock, git commit version, version==xorq_internal_version, user, hostname, ip_address(host ip)) + yaml_dict = translator.to_yaml(expr) + self.artifact_store.save_yaml(yaml_dict, expr_hash, "expr.yaml") + + sql_plans = generate_sql_plans(expr) + self.artifact_store.save_yaml(sql_plans, expr_hash, "sql.yaml") + + def load_expr(self, expr_hash: str) -> ir.Expr: + build_path = self.artifact_store.get_build_path(expr_hash) + translator = YamlExpressionTranslator( + current_path=build_path, profiles=self.profiles + ) + + yaml_dict = self.artifact_store.load_yaml(expr_hash, "expr.yaml") + return 
translator.from_yaml(yaml_dict) + + # TODO: maybe change name + def load_sql_plans(self, expr_hash: str) -> Dict[str, Any]: + return self.artifact_store.load_yaml(expr_hash, "sql.yaml") diff --git a/python/letsql/ibis_yaml/sql.py b/python/letsql/ibis_yaml/sql.py index d25b94b7..6f7e19f4 100644 --- a/python/letsql/ibis_yaml/sql.py +++ b/python/letsql/ibis_yaml/sql.py @@ -63,6 +63,7 @@ def traverse(node): return remote_tables +# TODO: rename to sqls def generate_sql_plans(expr: ir.Expr) -> SQLPlans: remote_tables = find_remote_tables(expr.op()) diff --git a/python/letsql/ibis_yaml/tests/conftest.py b/python/letsql/ibis_yaml/tests/conftest.py index f2d79d6d..8c56c6ff 100644 --- a/python/letsql/ibis_yaml/tests/conftest.py +++ b/python/letsql/ibis_yaml/tests/conftest.py @@ -787,6 +787,6 @@ def build_dir(tmp_path_factory): @pytest.fixture def compiler(build_dir): - from letsql.ibis_yaml.compiler import IbisYamlCompiler + from letsql.ibis_yaml.compiler import YamlExpressionTranslator - return IbisYamlCompiler(build_dir) + return YamlExpressionTranslator() diff --git a/python/letsql/ibis_yaml/tests/test_arithmetic.py b/python/letsql/ibis_yaml/tests/test_arithmetic.py index 2fc6a11d..823fbf0f 100644 --- a/python/letsql/ibis_yaml/tests/test_arithmetic.py +++ b/python/letsql/ibis_yaml/tests/test_arithmetic.py @@ -6,7 +6,7 @@ def test_add(compiler): lit2 = ibis.literal(3) expr = lit1 + lit2 - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Add" assert expression["args"][0]["op"] == "Literal" @@ -15,7 +15,7 @@ def test_add(compiler): assert expression["args"][1]["value"] == 3 assert expression["type"] == {"name": "Int8", "nullable": True} - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -24,7 +24,7 @@ def test_subtract(compiler): lit2 = ibis.literal(3) expr = lit1 - lit2 - yaml_dict = 
compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Subtract" assert expression["args"][0]["op"] == "Literal" @@ -33,7 +33,7 @@ def test_subtract(compiler): assert expression["args"][1]["value"] == 3 assert expression["type"] == {"name": "Int8", "nullable": True} - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -42,7 +42,7 @@ def test_multiply(compiler): lit2 = ibis.literal(3) expr = lit1 * lit2 - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Multiply" assert expression["args"][0]["op"] == "Literal" @@ -51,7 +51,7 @@ def test_multiply(compiler): assert expression["args"][1]["value"] == 3 assert expression["type"] == {"name": "Int8", "nullable": True} - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -60,7 +60,7 @@ def test_divide(compiler): lit2 = ibis.literal(2.0) expr = lit1 / lit2 - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Divide" assert expression["args"][0]["op"] == "Literal" @@ -69,7 +69,7 @@ def test_divide(compiler): assert expression["args"][1]["value"] == 2.0 assert expression["type"] == {"name": "Float64", "nullable": True} - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -78,12 +78,12 @@ def test_mixed_arithmetic(compiler): f = ibis.literal(2.5) expr = i * f - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Multiply" assert expression["type"] == {"name": "Float64", "nullable": True} - roundtrip_expr = 
compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -93,11 +93,11 @@ def test_complex_arithmetic(compiler): c = ibis.literal(2.0) expr = (a + b) * c - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Multiply" assert expression["args"][0]["op"] == "Add" - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_basic.py b/python/letsql/ibis_yaml/tests/test_basic.py index 85f20294..11f33ee4 100644 --- a/python/letsql/ibis_yaml/tests/test_basic.py +++ b/python/letsql/ibis_yaml/tests/test_basic.py @@ -5,52 +5,52 @@ def test_unbound_table(t, compiler): - yaml_dict = compiler.compile_to_yaml(t) + yaml_dict = compiler.to_yaml(t) expression = yaml_dict["expression"] assert expression["op"] == "UnboundTable" assert expression["name"] == "test_table" assert expression["schema_ref"] - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.schema() == t.schema() assert roundtrip_expr.op().name == t.op().name def test_field(t, compiler): expr = t.a - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Field" assert expression["name"] == "a" assert expression["type"] == {"name": "Int64", "nullable": True} - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) assert roundtrip_expr.get_name() == expr.get_name() def test_literal(compiler): lit = ibis.literal(42) - yaml_dict = compiler.compile_to_yaml(lit) + yaml_dict = compiler.to_yaml(lit) expression = yaml_dict["expression"] assert expression["op"] == "Literal" assert expression["value"] 
== 42 assert expression["type"] == {"name": "Int8", "nullable": True} - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(lit) def test_binary_op(t, compiler): expr = t.a + 1 - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Add" assert expression["args"][0]["op"] == "Field" assert expression["args"][1]["op"] == "Literal" - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -64,13 +64,13 @@ def test_primitive_types(compiler): (ibis.literal(None), "Null"), ] for lit, expected_type in primitives: - yaml_dict = compiler.compile_to_yaml(lit) + yaml_dict = compiler.to_yaml(lit) expression = yaml_dict["expression"] assert expression["op"] == "Literal" assert expression["type"]["name"] == expected_type - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(lit) assert roundtrip_expr.type().name == lit.type().name @@ -85,12 +85,12 @@ def test_temporal_types(compiler): (ibis.literal(time), "Time"), ] for lit, expected_type in temporals: - yaml_dict = compiler.compile_to_yaml(lit) + yaml_dict = compiler.to_yaml(lit) expression = yaml_dict["expression"] assert expression["op"] == "Literal" assert expression["type"]["name"] == expected_type - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(lit) assert roundtrip_expr.type().name == lit.type().name @@ -98,41 +98,41 @@ def test_temporal_types(compiler): def test_decimal_type(compiler): dec = decimal.Decimal("123.45") lit = ibis.literal(dec) - yaml_dict = compiler.compile_to_yaml(lit) + yaml_dict = compiler.to_yaml(lit) expression = yaml_dict["expression"] assert expression["op"] == "Literal" assert 
expression["type"]["name"] == "Decimal" assert expression["type"]["nullable"] - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(lit) assert roundtrip_expr.type().name == lit.type().name def test_array_type(compiler): lit = ibis.literal([1, 2, 3]) - yaml_dict = compiler.compile_to_yaml(lit) + yaml_dict = compiler.to_yaml(lit) expression = yaml_dict["expression"] assert expression["op"] == "Literal" assert expression["type"]["name"] == "Array" assert expression["type"]["value_type"]["name"] == "Int8" assert expression["value"] == (1, 2, 3) - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(lit) assert roundtrip_expr.type().value_type == lit.type().value_type def test_map_type(compiler): lit = ibis.literal({"a": 1, "b": 2}) - yaml_dict = compiler.compile_to_yaml(lit) + yaml_dict = compiler.to_yaml(lit) expression = yaml_dict["expression"] assert expression["op"] == "Literal" assert expression["type"]["name"] == "Map" assert expression["type"]["key_type"]["name"] == "String" assert expression["type"]["value_type"]["name"] == "Int8" - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(lit) assert roundtrip_expr.type().key_type == lit.type().key_type assert roundtrip_expr.type().value_type == lit.type().value_type @@ -140,28 +140,28 @@ def test_map_type(compiler): def test_complex_expression_roundtrip(t, compiler): expr = (t.a + 1).abs() * 2 - yaml_dict = compiler.compile_to_yaml(expr) - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) def test_window_function_roundtrip(t, compiler): expr = t.a.sum().over(ibis.window(group_by=t.a)) - yaml_dict = compiler.compile_to_yaml(expr) - roundtrip_expr = 
compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) def test_join_roundtrip(t, compiler): t2 = ibis.table({"b": "int64"}, name="test_table_2") expr = t.join(t2, t.a == t2.b) - yaml_dict = compiler.compile_to_yaml(expr) - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.schema() == expr.schema() def test_aggregation_roundtrip(t, compiler): expr = t.group_by(t.a).aggregate(count=t.a.count()) - yaml_dict = compiler.compile_to_yaml(expr) - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.schema() == expr.schema() diff --git a/python/letsql/ibis_yaml/tests/test_compiler.py b/python/letsql/ibis_yaml/tests/test_compiler.py index 2f22e550..b065ce97 100644 --- a/python/letsql/ibis_yaml/tests/test_compiler.py +++ b/python/letsql/ibis_yaml/tests/test_compiler.py @@ -6,28 +6,26 @@ from ibis.common.collections import FrozenOrderedDict import letsql as ls -from letsql.ibis_yaml.compiler import BuildManager, IbisYamlCompiler +from letsql.ibis_yaml.compiler import ArtifactStore, BuildManager def test_build_manager_expr_hash(t, build_dir): expected = "c6527994ad9a" - build_manager = BuildManager(build_dir) + build_manager = ArtifactStore(build_dir) result = build_manager.get_expr_hash(t) assert expected == result def test_build_manager_roundtrip(t, build_dir): - build_manager = BuildManager(build_dir) + build_manager = ArtifactStore(build_dir) expr_hash = "c6527994ad9a" yaml_dict = {"a": "string"} - build_manager.save_yaml(yaml_dict, expr_hash) + build_manager.save_yaml(yaml_dict, expr_hash, "expr.yaml") with open(build_dir / expr_hash / "expr.yaml") as f: out = f.read() assert out == "a: string\n" - result = build_manager.load_yaml(expr_hash) - - # 
assert os.path.exists(build_dir/ expr_hash / "sql.yaml") + result = build_manager.load_yaml(expr_hash, "expr.yaml") assert result == yaml_dict @@ -35,7 +33,7 @@ def test_build_manager_paths(t, build_dir): new_path = build_dir / "new_path" assert not os.path.exists(new_path) - build_manager = BuildManager(new_path) + build_manager = ArtifactStore(new_path) assert os.path.exists(new_path) build_manager.get_build_path("hash") @@ -43,7 +41,7 @@ def test_build_manager_paths(t, build_dir): def test_clean_frozen_dict_yaml(build_dir): - build_manager = BuildManager(build_dir) + build_manager = ArtifactStore(build_dir) data = FrozenOrderedDict( {"string": "text", "integer": 42, "float": 3.14, "boolean": True, "none": None} ) @@ -54,7 +52,7 @@ def test_clean_frozen_dict_yaml(build_dir): boolean: true none: null """ - out_path = build_manager.save_yaml(data, "hash") + out_path = build_manager.save_yaml(data, "hash", "expr.yaml") result = out_path.read_text() assert expected_yaml == result @@ -65,12 +63,12 @@ def test_ibis_compiler(t, build_dir): backend = t._find_backend() backend.profile_name = "default" expr = t.filter(t.a == 1).drop("b") - compiler = IbisYamlCompiler(build_dir) + compiler = BuildManager(build_dir) compiler.profiles = {"default": backend} - compiler.compile(expr) + compiler.compile_expr(expr) expr_hash = dask.base.tokenize(expr)[:12] - roundtrip_expr = compiler.from_hash(expr_hash) + roundtrip_expr = compiler.load_expr(expr_hash) assert expr.execute().equals(roundtrip_expr.execute()) @@ -84,11 +82,11 @@ def test_ibis_compiler_parquet_reader(t, build_dir): ) expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") - compiler = IbisYamlCompiler(build_dir) + compiler = BuildManager(build_dir) compiler.profiles = {"default": backend} - compiler.compile(expr) + compiler.compile_expr(expr) expr_hash = "5ebaf6a7a02d" - roundtrip_expr = compiler.from_hash(expr_hash) + roundtrip_expr = compiler.load_expr(expr_hash) assert 
expr.execute().equals(roundtrip_expr.execute()) @@ -102,11 +100,11 @@ def test_compiler_sql(build_dir): ) expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") - compiler = IbisYamlCompiler(build_dir) + compiler = BuildManager(build_dir) compiler.profiles = {"default": backend} - compiler.compile(expr) + compiler.compile_expr(expr) expr_hash = "5ebaf6a7a02d" - _roundtrip_expr = compiler.from_hash(expr_hash) + _roundtrip_expr = compiler.load_expr(expr_hash) assert os.path.exists(build_dir / expr_hash / "sql.yaml") @@ -131,9 +129,9 @@ def test_ibis_compiler_expr_schema_ref(t, build_dir): backend = t._find_backend() backend.profile_name = "default" expr = t.filter(t.a == 1).drop("b") - compiler = IbisYamlCompiler(build_dir) + compiler = BuildManager(build_dir) compiler.profiles = {"default": backend} - compiler.compile(expr) + compiler.compile_expr(expr) expr_hash = dask.base.tokenize(expr)[:12] with open(build_dir / expr_hash / "expr.yaml") as f: diff --git a/python/letsql/ibis_yaml/tests/test_join_chain.py b/python/letsql/ibis_yaml/tests/test_join_chain.py index 11fc10f5..88e5e4e1 100644 --- a/python/letsql/ibis_yaml/tests/test_join_chain.py +++ b/python/letsql/ibis_yaml/tests/test_join_chain.py @@ -89,8 +89,8 @@ def test_minimal_joinchain_self_reference( ) ) - yaml_dict = compiler.compile_to_yaml(q) - q_roundtrip = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(q) + q_roundtrip = compiler.from_yaml(yaml_dict) try: _ = q_roundtrip["cust_nation"] diff --git a/python/letsql/ibis_yaml/tests/test_letsql_ops.py b/python/letsql/ibis_yaml/tests/test_letsql_ops.py index 7b826322..62a7ffbd 100644 --- a/python/letsql/ibis_yaml/tests/test_letsql_ops.py +++ b/python/letsql/ibis_yaml/tests/test_letsql_ops.py @@ -2,7 +2,7 @@ import letsql as ls from letsql.expr.relations import into_backend -from letsql.ibis_yaml.compiler import IbisYamlCompiler +from letsql.ibis_yaml.compiler import YamlExpressionTranslator 
@pytest.fixture(scope="session") @@ -39,17 +39,16 @@ def test_duckdb_database_table_roundtrip(prepare_duckdb_con, build_dir): profiles = {"my_duckdb": con} - table_expr = con.table("mytable") # DatabaseTable op + table_expr = con.table("mytable") expr1 = table_expr.mutate(new_val=(table_expr.val + "_extra")) - compiler = IbisYamlCompiler(build_dir) - compiler.profiles = profiles + compiler = YamlExpressionTranslator(current_path=build_dir, profiles=profiles) - yaml_dict = compiler.compile_to_yaml(expr1) + yaml_dict = compiler.to_yaml(expr1) print("Serialized YAML:\n", yaml_dict) - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) df_original = expr1.execute() df_roundtrip = roundtrip_expr.execute() @@ -65,11 +64,10 @@ def test_memtable(prepare_duckdb_con, build_dir): profiles = {"default-duckdb": backend} - compiler = IbisYamlCompiler(build_dir) - compiler.profiles = profiles + compiler = YamlExpressionTranslator(current_path=build_dir, profiles=profiles) - yaml_dict = compiler.compile_to_yaml(expr) - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr) + roundtrip_expr = compiler.from_yaml(yaml_dict) expr.equals(roundtrip_expr) @@ -96,11 +94,10 @@ def test_into_backend(prepare_duckdb_con, build_dir): "default-datafusion": con3, } - compiler = IbisYamlCompiler(build_dir) - compiler.profiles = profiles + compiler = YamlExpressionTranslator(current_path=build_dir, profiles=profiles) - yaml_dict = compiler.compile_to_yaml(expr) - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert ls.execute(expr).equals(ls.execute(roundtrip_expr)) @@ -115,10 +112,10 @@ def test_memtable_cache(prepare_duckdb_con, build_dir): profiles = {"default-duckdb": backend, "default-let": backend1} - compiler = IbisYamlCompiler(build_dir) + compiler = YamlExpressionTranslator(profiles=profiles, 
current_path=build_dir) compiler.profiles = profiles - yaml_dict = compiler.compile_to_yaml(expr) - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert ls.execute(expr).equals(ls.execute(roundtrip_expr)) diff --git a/python/letsql/ibis_yaml/tests/test_operations_boolean.py b/python/letsql/ibis_yaml/tests/test_operations_boolean.py index 4d0171e6..922519f7 100644 --- a/python/letsql/ibis_yaml/tests/test_operations_boolean.py +++ b/python/letsql/ibis_yaml/tests/test_operations_boolean.py @@ -5,13 +5,13 @@ def test_equals(compiler): a = ibis.literal(5) b = ibis.literal(5) expr = a == b - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Equals" assert expression["args"][0]["value"] == 5 assert expression["args"][1]["value"] == 5 assert expression["type"] == {"name": "Boolean", "nullable": True} - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -19,12 +19,12 @@ def test_not_equals(compiler): a = ibis.literal(5) b = ibis.literal(3) expr = a != b - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "NotEquals" assert expression["args"][0]["value"] == 5 assert expression["args"][1]["value"] == 3 - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -32,12 +32,12 @@ def test_greater_than(compiler): a = ibis.literal(5) b = ibis.literal(3) expr = a > b - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Greater" assert expression["args"][0]["value"] == 5 assert expression["args"][1]["value"] == 3 - roundtrip_expr = 
compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -45,12 +45,12 @@ def test_less_than(compiler): a = ibis.literal(3) b = ibis.literal(5) expr = a < b - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Less" assert expression["args"][0]["value"] == 3 assert expression["args"][1]["value"] == 5 - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -60,54 +60,54 @@ def test_and_or(compiler): c = ibis.literal(10) expr_and = (a > b) & (a < c) - yaml_dict = compiler.compile_to_yaml(expr_and) + yaml_dict = compiler.to_yaml(expr_and) expression = yaml_dict["expression"] assert expression["op"] == "And" assert expression["args"][0]["op"] == "Greater" assert expression["args"][1]["op"] == "Less" - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr_and) expr_or = (a > b) | (a < c) - yaml_dict = compiler.compile_to_yaml(expr_or) + yaml_dict = compiler.to_yaml(expr_or) expression = yaml_dict["expression"] assert expression["op"] == "Or" assert expression["args"][0]["op"] == "Greater" assert expression["args"][1]["op"] == "Less" - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr_or) def test_not(compiler): a = ibis.literal(True) expr = ~a - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Not" assert expression["args"][0]["value"] - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) def test_is_null(compiler): a = ibis.literal(None) expr = a.isnull() - yaml_dict = 
compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "IsNull" assert expression["args"][0]["value"] is None - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) def test_between(compiler): a = ibis.literal(5) expr = a.between(3, 7) - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Between" assert expression["args"][0]["value"] == 5 assert expression["args"][1]["value"] == 3 assert expression["args"][2]["value"] == 7 - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_operations_cast.py b/python/letsql/ibis_yaml/tests/test_operations_cast.py index 38380739..2f6b5264 100644 --- a/python/letsql/ibis_yaml/tests/test_operations_cast.py +++ b/python/letsql/ibis_yaml/tests/test_operations_cast.py @@ -3,7 +3,7 @@ def test_explicit_cast(compiler): expr = ibis.literal(42).cast("float64") - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Cast" @@ -11,7 +11,7 @@ def test_explicit_cast(compiler): assert expression["args"][0]["value"] == 42 assert expression["type"]["name"] == "Float64" - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -19,7 +19,7 @@ def test_implicit_cast(compiler): i = ibis.literal(1) f = ibis.literal(2.5) expr = i + f - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Add" @@ -27,31 +27,31 @@ def test_implicit_cast(compiler): assert expression["args"][1]["type"]["name"] == "Float64" 
assert expression["type"]["name"] == "Float64" - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) def test_string_cast(compiler): expr = ibis.literal("42").cast("int64") - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Cast" assert expression["args"][0]["value"] == "42" assert expression["type"]["name"] == "Int64" - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) def test_timestamp_cast(compiler): expr = ibis.literal("2024-01-01").cast("timestamp") - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Cast" assert expression["args"][0]["value"] == "2024-01-01" assert expression["type"]["name"] == "Timestamp" - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_operations_datetime.py b/python/letsql/ibis_yaml/tests/test_operations_datetime.py index ded9fc91..68ae1f8f 100644 --- a/python/letsql/ibis_yaml/tests/test_operations_datetime.py +++ b/python/letsql/ibis_yaml/tests/test_operations_datetime.py @@ -11,26 +11,26 @@ def test_date_extract(compiler): dt_expr = ibis.literal(datetime(2024, 3, 14, 15, 9, 26)) year = dt_expr.year() - year_yaml = compiler.compile_to_yaml(year) + year_yaml = compiler.to_yaml(year) expression = year_yaml["expression"] assert expression["op"] == "ExtractYear" assert expression["args"][0]["value"] == "2024-03-14T15:09:26" assert expression["type"]["name"] == "Int32" - roundtrip_year = compiler.compile_from_yaml(year_yaml) + roundtrip_year = compiler.from_yaml(year_yaml) assert roundtrip_year.equals(year) month = dt_expr.month() - month_yaml = 
compiler.compile_to_yaml(month) + month_yaml = compiler.to_yaml(month) expression = month_yaml["expression"] assert expression["op"] == "ExtractMonth" - roundtrip_month = compiler.compile_from_yaml(month_yaml) + roundtrip_month = compiler.from_yaml(month_yaml) assert roundtrip_month.equals(month) day = dt_expr.day() - day_yaml = compiler.compile_to_yaml(day) + day_yaml = compiler.to_yaml(day) expression = day_yaml["expression"] assert expression["op"] == "ExtractDay" - roundtrip_day = compiler.compile_from_yaml(day_yaml) + roundtrip_day = compiler.from_yaml(day_yaml) assert roundtrip_day.equals(day) @@ -38,26 +38,26 @@ def test_time_extract(compiler): dt_expr = ibis.literal(datetime(2024, 3, 14, 15, 9, 26)) hour = dt_expr.hour() - hour_yaml = compiler.compile_to_yaml(hour) + hour_yaml = compiler.to_yaml(hour) hour_expression = hour_yaml["expression"] assert hour_expression["op"] == "ExtractHour" assert hour_expression["args"][0]["value"] == "2024-03-14T15:09:26" assert hour_expression["type"]["name"] == "Int32" - roundtrip_hour = compiler.compile_from_yaml(hour_yaml) + roundtrip_hour = compiler.from_yaml(hour_yaml) assert roundtrip_hour.equals(hour) minute = dt_expr.minute() - minute_yaml = compiler.compile_to_yaml(minute) + minute_yaml = compiler.to_yaml(minute) minute_expression = minute_yaml["expression"] assert minute_expression["op"] == "ExtractMinute" - roundtrip_minute = compiler.compile_from_yaml(minute_yaml) + roundtrip_minute = compiler.from_yaml(minute_yaml) assert roundtrip_minute.equals(minute) second = dt_expr.second() - second_yaml = compiler.compile_to_yaml(second) + second_yaml = compiler.to_yaml(second) second_expression = second_yaml["expression"] assert second_expression["op"] == "ExtractSecond" - roundtrip_second = compiler.compile_from_yaml(second_yaml) + roundtrip_second = compiler.from_yaml(second_yaml) assert roundtrip_second.equals(second) @@ -66,21 +66,21 @@ def test_timestamp_arithmetic(compiler): delta = ibis.interval(days=1) plus_day = 
ts + delta - yaml_dict = compiler.compile_to_yaml(plus_day) + yaml_dict = compiler.to_yaml(plus_day) expression = yaml_dict["expression"] assert expression["op"] == "TimestampAdd" assert expression["type"]["name"] == "Timestamp" assert expression["args"][1]["type"]["name"] == "Interval" - roundtrip_plus = compiler.compile_from_yaml(yaml_dict) + roundtrip_plus = compiler.from_yaml(yaml_dict) assert roundtrip_plus.equals(plus_day) minus_day = ts - delta - yaml_dict = compiler.compile_to_yaml(minus_day) + yaml_dict = compiler.to_yaml(minus_day) expression = yaml_dict["expression"] assert expression["op"] == "TimestampSub" assert expression["type"]["name"] == "Timestamp" assert expression["args"][1]["type"]["name"] == "Interval" - roundtrip_minus = compiler.compile_from_yaml(yaml_dict) + roundtrip_minus = compiler.from_yaml(yaml_dict) assert roundtrip_minus.equals(minus_day) @@ -88,29 +88,29 @@ def test_timestamp_diff(compiler): ts1 = ibis.literal(datetime(2024, 3, 14)) ts2 = ibis.literal(datetime(2024, 3, 15)) diff = ts2 - ts1 - yaml_dict = compiler.compile_to_yaml(diff) + yaml_dict = compiler.to_yaml(diff) expression = yaml_dict["expression"] assert expression["op"] == "TimestampDiff" assert expression["type"]["name"] == "Interval" - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(diff) def test_temporal_unit_yaml(compiler): interval_date = ibis.literal(5, type=dt.Interval(unit=tm.DateUnit("D"))) - yaml_date = compiler.compile_to_yaml(interval_date) + yaml_date = compiler.to_yaml(interval_date) expression_date = yaml_date["expression"] assert expression_date["type"]["name"] == "Interval" assert expression_date["type"]["unit"]["name"] == "DateUnit" assert expression_date["type"]["unit"]["value"] == "D" - roundtrip_date = compiler.compile_from_yaml(yaml_date) + roundtrip_date = compiler.from_yaml(yaml_date) assert roundtrip_date.equals(interval_date) interval_time = ibis.literal(10, 
type=dt.Interval(unit=tm.TimeUnit("h"))) - yaml_time = compiler.compile_to_yaml(interval_time) + yaml_time = compiler.to_yaml(interval_time) expression_time = yaml_time["expression"] assert expression_time["type"]["name"] == "Interval" assert expression_time["type"]["unit"]["name"] == "TimeUnit" assert expression_time["type"]["unit"]["value"] == "h" - roundtrip_time = compiler.compile_from_yaml(yaml_time) + roundtrip_time = compiler.from_yaml(yaml_time) assert roundtrip_time.equals(interval_time) diff --git a/python/letsql/ibis_yaml/tests/test_relations.py b/python/letsql/ibis_yaml/tests/test_relations.py index 19fe9b57..a028432f 100644 --- a/python/letsql/ibis_yaml/tests/test_relations.py +++ b/python/letsql/ibis_yaml/tests/test_relations.py @@ -3,7 +3,7 @@ def test_filter(compiler, t): expr = t.filter(t.a > 0) - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] # Original assertions @@ -12,13 +12,13 @@ def test_filter(compiler, t): assert expression["parent"]["op"] == "UnboundTable" # Roundtrip test: compile from YAML and verify equality - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) def test_projection(compiler, t): expr = t.select(["a", "b"]) - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] # Original assertions @@ -27,13 +27,13 @@ def test_projection(compiler, t): assert set(expression["values"]) == {"a", "b"} # Roundtrip test - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) def test_aggregation(compiler, t): expr = t.group_by("a").aggregate(avg_c=t.c.mean()) - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Aggregate" @@ -41,7 +41,7 @@ def 
test_aggregation(compiler, t): assert expression["metrics"]["avg_c"]["op"] == "Mean" # Roundtrip test - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -49,7 +49,7 @@ def test_join(compiler): t1 = ibis.table(dict(a="int", b="string"), name="t1") t2 = ibis.table(dict(b="string", c="float"), name="t2") expr = t1.join(t2, t1.b == t2.b) - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "JoinChain" @@ -57,28 +57,28 @@ def test_join(compiler): assert expression["rest"][0]["how"] == "inner" # Roundtrip test - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) def test_order_by(compiler, t): expr = t.order_by(["a", "b"]) - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Sort" assert len(expression["keys"]) == 2 - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) def test_limit(compiler, t): expr = t.limit(10) - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Limit" assert expression["n"] == 10 - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_selection.py b/python/letsql/ibis_yaml/tests/test_selection.py index 486bb859..567cf145 100644 --- a/python/letsql/ibis_yaml/tests/test_selection.py +++ b/python/letsql/ibis_yaml/tests/test_selection.py @@ -8,6 +8,6 @@ def test_selection_on_view(compiler): q = q.select({"alias_name": T_view.name}) q = q.filter(q.alias_name == "X") - yaml_dict = 
compiler.compile_to_yaml(q) - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(q) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(q) diff --git a/python/letsql/ibis_yaml/tests/test_string_ops.py b/python/letsql/ibis_yaml/tests/test_string_ops.py index bbfdc6f3..f4abb7d0 100644 --- a/python/letsql/ibis_yaml/tests/test_string_ops.py +++ b/python/letsql/ibis_yaml/tests/test_string_ops.py @@ -5,7 +5,7 @@ def test_string_concat(compiler): s1 = ibis.literal("hello") s2 = ibis.literal("world") expr = s1 + s2 - yaml_dict = compiler.compile_to_yaml(expr)["expression"] + yaml_dict = compiler.to_yaml(expr)["expression"] assert yaml_dict["op"] == "StringConcat" assert yaml_dict["args"][0]["value"] == "hello" @@ -18,11 +18,11 @@ def test_string_upper_lower(compiler): upper_expr = s.upper() lower_expr = s.lower() - upper_yaml = compiler.compile_to_yaml(upper_expr)["expression"] + upper_yaml = compiler.to_yaml(upper_expr)["expression"] assert upper_yaml["op"] == "Uppercase" assert upper_yaml["args"][0]["value"] == "Hello" - lower_yaml = compiler.compile_to_yaml(lower_expr)["expression"] + lower_yaml = compiler.to_yaml(lower_expr)["expression"] assert lower_yaml["op"] == "Lowercase" assert lower_yaml["args"][0]["value"] == "Hello" @@ -30,7 +30,7 @@ def test_string_upper_lower(compiler): def test_string_length(compiler): s = ibis.literal("hello") expr = s.length() - yaml_dict = compiler.compile_to_yaml(expr)["expression"] + yaml_dict = compiler.to_yaml(expr)["expression"] assert yaml_dict["op"] == "StringLength" assert yaml_dict["args"][0]["value"] == "hello" @@ -40,7 +40,7 @@ def test_string_length(compiler): def test_string_substring(compiler): s = ibis.literal("hello world") expr = s.substr(0, 5) - yaml_dict = compiler.compile_to_yaml(expr)["expression"] + yaml_dict = compiler.to_yaml(expr)["expression"] assert yaml_dict["op"] == "Substring" assert yaml_dict["args"][0]["value"] == "hello world" diff --git 
a/python/letsql/ibis_yaml/tests/test_subquery.py b/python/letsql/ibis_yaml/tests/test_subquery.py index d9056aa2..6d27393f 100644 --- a/python/letsql/ibis_yaml/tests/test_subquery.py +++ b/python/letsql/ibis_yaml/tests/test_subquery.py @@ -4,13 +4,13 @@ def test_scalar_subquery(compiler, t): expr = ops.ScalarSubquery(t.c.mean().as_table()).to_expr() - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "ScalarSubquery" assert expression["args"][0]["op"] == "Aggregate" - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -20,13 +20,13 @@ def test_exists_subquery(compiler): filtered = t2.filter(t2.a == t1.a) expr = ops.ExistsSubquery(filtered).to_expr() - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "ExistsSubquery" assert expression["rel"]["op"] == "Filter" - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) @@ -35,11 +35,11 @@ def test_in_subquery(compiler): t2 = ibis.table(dict(a="int", c="float"), name="t2") expr = ops.InSubquery(t1.select("a"), t2.a).to_expr() - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "InSubquery" assert expression["type"]["name"] == "Boolean" - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_tpch.py b/python/letsql/ibis_yaml/tests/test_tpch.py index a8e4d725..e8c6c216 100644 --- a/python/letsql/ibis_yaml/tests/test_tpch.py +++ b/python/letsql/ibis_yaml/tests/test_tpch.py @@ -31,8 +31,8 @@ def test_yaml_roundtrip(fixture_name, compiler, 
request): query = request.getfixturevalue(fixture_name) - yaml_dict = compiler.compile_to_yaml(query) - roundtrip_query = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(query) + roundtrip_query = compiler.from_yaml(yaml_dict) assert roundtrip_query.equals(query), ( f"Roundtrip expression for {fixture_name} does not match the original." diff --git a/python/letsql/ibis_yaml/tests/test_udf.py b/python/letsql/ibis_yaml/tests/test_udf.py index 52696d3e..b5e3b5bb 100644 --- a/python/letsql/ibis_yaml/tests/test_udf.py +++ b/python/letsql/ibis_yaml/tests/test_udf.py @@ -13,8 +13,8 @@ def add_one(x: int) -> int: return x + 1 expr = t.mutate(new=add_one(t.a)) - yaml_dict = compiler.compile_to_yaml(expr) - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr) + roundtrip_expr = compiler.from_yaml(yaml_dict) original_mutation = expr.op() roundtrip_mutation = roundtrip_expr.op() @@ -50,8 +50,8 @@ def add_one(x: int) -> int: pass expr = t.mutate(new=add_one(t.a)) - yaml_dict = compiler.compile_to_yaml(expr) - roundtrip_expr = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr) + roundtrip_expr = compiler.from_yaml(yaml_dict) print(f"Original {expr}") print(f"Roundtrip {roundtrip_expr}") letsql.ibis_yaml.utils.diff_ibis_exprs(expr, roundtrip_expr) diff --git a/python/letsql/ibis_yaml/tests/test_window_functions.py b/python/letsql/ibis_yaml/tests/test_window_functions.py index 53c6b53c..f0813f14 100644 --- a/python/letsql/ibis_yaml/tests/test_window_functions.py +++ b/python/letsql/ibis_yaml/tests/test_window_functions.py @@ -10,9 +10,9 @@ def test_window_function_roundtrip(compiler, t): ] ) - yaml_dict = compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) - reconstructed_expr = compiler.compile_from_yaml(yaml_dict) + reconstructed_expr = compiler.from_yaml(yaml_dict) assert expr.equals(reconstructed_expr) @@ -36,7 +36,7 @@ def test_aggregation_window(compiler, t): ] ) - yaml_dict = 
compiler.compile_to_yaml(expr) + yaml_dict = compiler.to_yaml(expr) expression = yaml_dict["expression"] assert expression["op"] == "Project" window_func = expression["values"]["mean_c"] @@ -58,8 +58,8 @@ def test_aggregation_window(compiler, t): def test_row_number_simple_roundtrip(compiler, t): expr = t.select([ibis.row_number().name("row_num")]) - yaml_dict = compiler.compile_to_yaml(expr) - reconstructed_expr = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr) + reconstructed_expr = compiler.from_yaml(yaml_dict) assert expr.equals(reconstructed_expr) @@ -78,8 +78,8 @@ def test_row_number_window_roundtrip(compiler, t): .name("row_num") ] ) - yaml_dict = compiler.compile_to_yaml(expr) - reconstructed_expr = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr) + reconstructed_expr = compiler.from_yaml(yaml_dict) assert expr.equals(reconstructed_expr) @@ -95,6 +95,6 @@ def test_multiple_rank_expressions_roundtrip(compiler, t): .name("mean_c"), ] ) - yaml_dict = compiler.compile_to_yaml(expr) - reconstructed_expr = compiler.compile_from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr) + reconstructed_expr = compiler.from_yaml(yaml_dict) assert expr.equals(reconstructed_expr) From 6c12bf4e049faff5205b4539b341bef6580cd9ea Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Mon, 17 Feb 2025 14:17:40 -0500 Subject: [PATCH 12/45] chore: change imports for rebase --- python/letsql/ibis_yaml/compiler.py | 4 ++-- python/letsql/ibis_yaml/sql.py | 7 +++---- python/letsql/ibis_yaml/tests/conftest.py | 9 +++++---- python/letsql/ibis_yaml/tests/test_arithmetic.py | 2 +- python/letsql/ibis_yaml/tests/test_basic.py | 2 +- python/letsql/ibis_yaml/tests/test_compiler.py | 2 +- python/letsql/ibis_yaml/tests/test_join_chain.py | 3 ++- .../ibis_yaml/tests/test_operations_boolean.py | 2 +- .../letsql/ibis_yaml/tests/test_operations_cast.py | 2 +- .../ibis_yaml/tests/test_operations_datetime.py | 6 +++--- 
python/letsql/ibis_yaml/tests/test_relations.py | 2 +- python/letsql/ibis_yaml/tests/test_selection.py | 2 +- python/letsql/ibis_yaml/tests/test_sql.py | 2 +- python/letsql/ibis_yaml/tests/test_string_ops.py | 2 +- python/letsql/ibis_yaml/tests/test_subquery.py | 4 ++-- python/letsql/ibis_yaml/tests/test_udf.py | 2 +- .../ibis_yaml/tests/test_window_functions.py | 2 +- python/letsql/ibis_yaml/translate.py | 14 +++++++------- python/letsql/ibis_yaml/utils.py | 2 +- 19 files changed, 36 insertions(+), 35 deletions(-) diff --git a/python/letsql/ibis_yaml/compiler.py b/python/letsql/ibis_yaml/compiler.py index 06f866ac..61518e95 100644 --- a/python/letsql/ibis_yaml/compiler.py +++ b/python/letsql/ibis_yaml/compiler.py @@ -3,10 +3,9 @@ from typing import Any, Dict import dask -import ibis.expr.types as ir import yaml -from ibis.common.collections import FrozenOrderedDict +import letsql.vendor.ibis.expr.types as ir from letsql.ibis_yaml.sql import generate_sql_plans from letsql.ibis_yaml.translate import ( SchemaRegistry, @@ -14,6 +13,7 @@ translate_to_yaml, ) from letsql.ibis_yaml.utils import freeze +from letsql.vendor.ibis.common.collections import FrozenOrderedDict # is this the right way to handle this? 
or the right place diff --git a/python/letsql/ibis_yaml/sql.py b/python/letsql/ibis_yaml/sql.py index 6f7e19f4..90d1536d 100644 --- a/python/letsql/ibis_yaml/sql.py +++ b/python/letsql/ibis_yaml/sql.py @@ -1,9 +1,8 @@ from typing import Any, Dict, TypedDict -import ibis -import ibis.expr.operations as ops -import ibis.expr.types as ir - +import letsql.vendor.ibis as ibis +import letsql.vendor.ibis.expr.operations as ops +import letsql.vendor.ibis.expr.types as ir from letsql.expr.relations import RemoteTable diff --git a/python/letsql/ibis_yaml/tests/conftest.py b/python/letsql/ibis_yaml/tests/conftest.py index 8c56c6ff..384380a5 100644 --- a/python/letsql/ibis_yaml/tests/conftest.py +++ b/python/letsql/ibis_yaml/tests/conftest.py @@ -1,9 +1,10 @@ from datetime import date -import ibis -import ibis.expr.datatypes as dt import pytest +import letsql.vendor.ibis as ibis +import letsql.vendor.ibis.expr.datatypes as dt + # Fixtures from: https://github.com/ibis-project/ibis-substrait/blob/main/ibis_substrait/tests/compiler/test_tpch.py @@ -249,8 +250,8 @@ def tpc_h03(customer, orders, lineitem): @pytest.fixture def tpc_h04(orders, lineitem): - from ibis import _ - from ibis.expr.operations import ExistsSubquery + from letsql.vendor.ibis import _ + from letsql.vendor.ibis.expr.operations import ExistsSubquery lineitem_filtered = lineitem.filter( [ diff --git a/python/letsql/ibis_yaml/tests/test_arithmetic.py b/python/letsql/ibis_yaml/tests/test_arithmetic.py index 823fbf0f..37d0b4f7 100644 --- a/python/letsql/ibis_yaml/tests/test_arithmetic.py +++ b/python/letsql/ibis_yaml/tests/test_arithmetic.py @@ -1,4 +1,4 @@ -import ibis +import letsql.vendor.ibis as ibis def test_add(compiler): diff --git a/python/letsql/ibis_yaml/tests/test_basic.py b/python/letsql/ibis_yaml/tests/test_basic.py index 11f33ee4..ce21f768 100644 --- a/python/letsql/ibis_yaml/tests/test_basic.py +++ b/python/letsql/ibis_yaml/tests/test_basic.py @@ -1,7 +1,7 @@ import datetime import decimal -import 
ibis +import letsql.vendor.ibis as ibis def test_unbound_table(t, compiler): diff --git a/python/letsql/ibis_yaml/tests/test_compiler.py b/python/letsql/ibis_yaml/tests/test_compiler.py index b065ce97..079fdbf0 100644 --- a/python/letsql/ibis_yaml/tests/test_compiler.py +++ b/python/letsql/ibis_yaml/tests/test_compiler.py @@ -3,10 +3,10 @@ import dask import yaml -from ibis.common.collections import FrozenOrderedDict import letsql as ls from letsql.ibis_yaml.compiler import ArtifactStore, BuildManager +from letsql.vendor.ibis.common.collections import FrozenOrderedDict def test_build_manager_expr_hash(t, build_dir): diff --git a/python/letsql/ibis_yaml/tests/test_join_chain.py b/python/letsql/ibis_yaml/tests/test_join_chain.py index 88e5e4e1..9da441c9 100644 --- a/python/letsql/ibis_yaml/tests/test_join_chain.py +++ b/python/letsql/ibis_yaml/tests/test_join_chain.py @@ -1,6 +1,7 @@ -import ibis import pytest +import letsql.vendor.ibis as ibis + @pytest.fixture def orders(): diff --git a/python/letsql/ibis_yaml/tests/test_operations_boolean.py b/python/letsql/ibis_yaml/tests/test_operations_boolean.py index 922519f7..637131e5 100644 --- a/python/letsql/ibis_yaml/tests/test_operations_boolean.py +++ b/python/letsql/ibis_yaml/tests/test_operations_boolean.py @@ -1,4 +1,4 @@ -import ibis +import letsql.vendor.ibis as ibis def test_equals(compiler): diff --git a/python/letsql/ibis_yaml/tests/test_operations_cast.py b/python/letsql/ibis_yaml/tests/test_operations_cast.py index 2f6b5264..03eb6219 100644 --- a/python/letsql/ibis_yaml/tests/test_operations_cast.py +++ b/python/letsql/ibis_yaml/tests/test_operations_cast.py @@ -1,4 +1,4 @@ -import ibis +import letsql.vendor.ibis as ibis def test_explicit_cast(compiler): diff --git a/python/letsql/ibis_yaml/tests/test_operations_datetime.py b/python/letsql/ibis_yaml/tests/test_operations_datetime.py index 68ae1f8f..e1339252 100644 --- a/python/letsql/ibis_yaml/tests/test_operations_datetime.py +++ 
b/python/letsql/ibis_yaml/tests/test_operations_datetime.py @@ -2,9 +2,9 @@ from datetime import datetime -import ibis -import ibis.expr.datatypes as dt -import ibis.expr.operations.temporal as tm +import letsql.vendor.ibis as ibis +import letsql.vendor.ibis.expr.datatypes as dt +import letsql.vendor.ibis.expr.operations.temporal as tm def test_date_extract(compiler): diff --git a/python/letsql/ibis_yaml/tests/test_relations.py b/python/letsql/ibis_yaml/tests/test_relations.py index a028432f..3c57cd2d 100644 --- a/python/letsql/ibis_yaml/tests/test_relations.py +++ b/python/letsql/ibis_yaml/tests/test_relations.py @@ -1,4 +1,4 @@ -import ibis +import letsql.vendor.ibis as ibis def test_filter(compiler, t): diff --git a/python/letsql/ibis_yaml/tests/test_selection.py b/python/letsql/ibis_yaml/tests/test_selection.py index 567cf145..ba04e9ad 100644 --- a/python/letsql/ibis_yaml/tests/test_selection.py +++ b/python/letsql/ibis_yaml/tests/test_selection.py @@ -1,4 +1,4 @@ -import ibis +import letsql.vendor.ibis as ibis def test_selection_on_view(compiler): diff --git a/python/letsql/ibis_yaml/tests/test_sql.py b/python/letsql/ibis_yaml/tests/test_sql.py index dc735217..2b158d8e 100644 --- a/python/letsql/ibis_yaml/tests/test_sql.py +++ b/python/letsql/ibis_yaml/tests/test_sql.py @@ -1,7 +1,7 @@ -import ibis.expr.operations as ops import pytest import letsql as ls +import letsql.vendor.ibis.expr.operations as ops from letsql.expr.relations import RemoteTable, into_backend from letsql.ibis_yaml.sql import find_remote_tables, generate_sql_plans diff --git a/python/letsql/ibis_yaml/tests/test_string_ops.py b/python/letsql/ibis_yaml/tests/test_string_ops.py index f4abb7d0..4af5d59e 100644 --- a/python/letsql/ibis_yaml/tests/test_string_ops.py +++ b/python/letsql/ibis_yaml/tests/test_string_ops.py @@ -1,4 +1,4 @@ -import ibis +import letsql.vendor.ibis as ibis def test_string_concat(compiler): diff --git a/python/letsql/ibis_yaml/tests/test_subquery.py 
b/python/letsql/ibis_yaml/tests/test_subquery.py index 6d27393f..2b75fa08 100644 --- a/python/letsql/ibis_yaml/tests/test_subquery.py +++ b/python/letsql/ibis_yaml/tests/test_subquery.py @@ -1,5 +1,5 @@ -import ibis -import ibis.expr.operations as ops +import letsql.vendor.ibis as ibis +import letsql.vendor.ibis.expr.operations as ops def test_scalar_subquery(compiler, t): diff --git a/python/letsql/ibis_yaml/tests/test_udf.py b/python/letsql/ibis_yaml/tests/test_udf.py index b5e3b5bb..5429dd87 100644 --- a/python/letsql/ibis_yaml/tests/test_udf.py +++ b/python/letsql/ibis_yaml/tests/test_udf.py @@ -1,8 +1,8 @@ -import ibis import pytest import letsql.ibis_yaml import letsql.ibis_yaml.utils +import letsql.vendor.ibis as ibis def test_built_in_udf_properties(compiler): diff --git a/python/letsql/ibis_yaml/tests/test_window_functions.py b/python/letsql/ibis_yaml/tests/test_window_functions.py index f0813f14..aeb11066 100644 --- a/python/letsql/ibis_yaml/tests/test_window_functions.py +++ b/python/letsql/ibis_yaml/tests/test_window_functions.py @@ -1,4 +1,4 @@ -import ibis +import letsql.vendor.ibis as ibis def test_window_function_roundtrip(compiler, t): diff --git a/python/letsql/ibis_yaml/translate.py b/python/letsql/ibis_yaml/translate.py index 0af15a70..bab90a10 100644 --- a/python/letsql/ibis_yaml/translate.py +++ b/python/letsql/ibis_yaml/translate.py @@ -5,16 +5,15 @@ import functools from typing import Any -import ibis -import ibis.expr.datatypes as dt -import ibis.expr.operations as ops -import ibis.expr.operations.temporal as tm -import ibis.expr.rules as rlz -import ibis.expr.types as ir import pyarrow.parquet as pq -from ibis.common.annotations import Argument import letsql as ls +import letsql.vendor.ibis as ibis +import letsql.vendor.ibis.expr.datatypes as dt +import letsql.vendor.ibis.expr.operations as ops +import letsql.vendor.ibis.expr.operations.temporal as tm +import letsql.vendor.ibis.expr.rules as rlz +import letsql.vendor.ibis.expr.types as ir 
from letsql.expr.relations import CachedNode, RemoteTable, into_backend from letsql.ibis_yaml.utils import ( deserialize_udf_function, @@ -23,6 +22,7 @@ serialize_udf_function, translate_storage, ) +from letsql.vendor.ibis.common.annotations import Argument FROM_YAML_HANDLERS: dict[str, Any] = {} diff --git a/python/letsql/ibis_yaml/utils.py b/python/letsql/ibis_yaml/utils.py index a791d0f2..4d746c97 100644 --- a/python/letsql/ibis_yaml/utils.py +++ b/python/letsql/ibis_yaml/utils.py @@ -3,9 +3,9 @@ from typing import Any, Dict import cloudpickle -from ibis.common.collections import FrozenOrderedDict from letsql.common.caching import SourceStorage +from letsql.vendor.ibis.common.collections import FrozenOrderedDict def serialize_udf_function(fn: callable) -> str: From dc34584114ee447d70660d6e3ddf3cd67e93abfc Mon Sep 17 00:00:00 2001 From: dlovell Date: Sun, 16 Feb 2025 08:26:49 -0500 Subject: [PATCH 13/45] feat: add Profiles --- python/xorq/vendor/ibis/backends/__init__.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/xorq/vendor/ibis/backends/__init__.py b/python/xorq/vendor/ibis/backends/__init__.py index 24f634e3..4d76e9bb 100644 --- a/python/xorq/vendor/ibis/backends/__init__.py +++ b/python/xorq/vendor/ibis/backends/__init__.py @@ -34,6 +34,16 @@ from xorq.common.utils.inspect_utils import get_arguments from xorq.vendor import ibis from xorq.vendor.ibis import util +import dask +import toolz +from attr import ( + field, + frozen, +) +from attr.validators import ( + instance_of, + optional, +) if TYPE_CHECKING: From 686548c4d0f09c108dbdb90f4d93bfb2e265aa5f Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Mon, 17 Feb 2025 16:12:41 -0500 Subject: [PATCH 14/45] feat: use Profile in YamlExpressionCompiler --- python/letsql/ibis_yaml/compiler.py | 26 +++++++++++-- python/letsql/ibis_yaml/sql.py | 14 +------ .../letsql/ibis_yaml/tests/test_compiler.py | 14 ++----- .../letsql/ibis_yaml/tests/test_letsql_ops.py | 22 +++++------ 
python/letsql/ibis_yaml/tests/test_sql.py | 29 -------------- python/letsql/ibis_yaml/translate.py | 30 +++----------- python/letsql/ibis_yaml/utils.py | 39 ++++++++++++++++++- 7 files changed, 80 insertions(+), 94 deletions(-) diff --git a/python/letsql/ibis_yaml/compiler.py b/python/letsql/ibis_yaml/compiler.py index 61518e95..ac8164e5 100644 --- a/python/letsql/ibis_yaml/compiler.py +++ b/python/letsql/ibis_yaml/compiler.py @@ -12,7 +12,8 @@ translate_from_yaml, translate_to_yaml, ) -from letsql.ibis_yaml.utils import freeze +from letsql.ibis_yaml.utils import find_remote_backends, freeze +from letsql.vendor.ibis.backends import Profile from letsql.vendor.ibis.common.collections import FrozenOrderedDict @@ -138,20 +139,39 @@ def compile_expr(self, expr: ir.Expr) -> None: expr_hash = self.artifact_store.get_expr_hash(expr) current_path = self.artifact_store.get_build_path(expr_hash) + backends = (expr._find_backend(), *find_remote_backends(expr.op())) + profiles = { + backend._profile.hash_name: backend._profile.as_dict() + for backend in backends + } + translator = YamlExpressionTranslator( - profiles=self.profiles, current_path=current_path + profiles=profiles, current_path=current_path ) # metadata.yaml (uv.lock, git commit version, version==xorq_internal_version, user, hostname, ip_address(host ip)) yaml_dict = translator.to_yaml(expr) self.artifact_store.save_yaml(yaml_dict, expr_hash, "expr.yaml") + self.artifact_store.save_yaml(profiles, expr_hash, "profiles.yaml") + sql_plans = generate_sql_plans(expr) self.artifact_store.save_yaml(sql_plans, expr_hash, "sql.yaml") def load_expr(self, expr_hash: str) -> ir.Expr: build_path = self.artifact_store.get_build_path(expr_hash) + profiles_dict = self.artifact_store.load_yaml(expr_hash, "profiles.yaml") + + def f(values): + dct = dict(values) + dct["kwargs_tuple"] = tuple(map(tuple, dct["kwargs_tuple"])) + return dct + + profiles = { + profile: Profile(**f(values)).get_con() + for profile, values in 
profiles_dict.items() + } translator = YamlExpressionTranslator( - current_path=build_path, profiles=self.profiles + current_path=build_path, profiles=profiles ) yaml_dict = self.artifact_store.load_yaml(expr_hash, "expr.yaml") diff --git a/python/letsql/ibis_yaml/sql.py b/python/letsql/ibis_yaml/sql.py index 90d1536d..24a00874 100644 --- a/python/letsql/ibis_yaml/sql.py +++ b/python/letsql/ibis_yaml/sql.py @@ -29,16 +29,9 @@ def traverse(node): if isinstance(node, ops.Node) and isinstance(node, RemoteTable): remote_expr = node.remote_expr original_backend = remote_expr._find_backend() - if ( - not hasattr(original_backend, "profile_name") - or original_backend.profile_name is None - ): - raise AttributeError( - "Backend does not have a valid 'profile_name' attribute." - ) engine_name = original_backend.name - profile_name = original_backend.profile_name + profile_name = original_backend._profile.hash_name remote_tables[node.name] = { "engine": engine_name, "profile_name": profile_name, @@ -69,11 +62,8 @@ def generate_sql_plans(expr: ir.Expr) -> SQLPlans: main_sql = ibis.to_sql(expr) backend = expr._find_backend() - if not hasattr(backend, "profile_name") or backend.profile_name is None: - raise AttributeError("Backend does not have a valid 'profile_name' attribute.") - engine_name = backend.name - profile_name = backend.profile_name + profile_name = backend._profile.hash_name plans: SQLPlans = { "queries": { diff --git a/python/letsql/ibis_yaml/tests/test_compiler.py b/python/letsql/ibis_yaml/tests/test_compiler.py index 079fdbf0..58e001da 100644 --- a/python/letsql/ibis_yaml/tests/test_compiler.py +++ b/python/letsql/ibis_yaml/tests/test_compiler.py @@ -2,6 +2,7 @@ import pathlib import dask +import pytest import yaml import letsql as ls @@ -60,11 +61,8 @@ def test_clean_frozen_dict_yaml(build_dir): def test_ibis_compiler(t, build_dir): t = ls.memtable({"a": [0, 1], "b": [0, 1]}) - backend = t._find_backend() - backend.profile_name = "default" expr = t.filter(t.a 
== 1).drop("b") compiler = BuildManager(build_dir) - compiler.profiles = {"default": backend} compiler.compile_expr(expr) expr_hash = dask.base.tokenize(expr)[:12] @@ -73,9 +71,9 @@ def test_ibis_compiler(t, build_dir): assert expr.execute().equals(roundtrip_expr.execute()) +@pytest.mark.xfail def test_ibis_compiler_parquet_reader(t, build_dir): backend = ls.datafusion.connect() - backend.profile_name = "default" awards_players = backend.read_parquet( ls.config.options.pins.get_path("awards_players"), table_name="awards_players", @@ -83,7 +81,6 @@ def test_ibis_compiler_parquet_reader(t, build_dir): expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") compiler = BuildManager(build_dir) - compiler.profiles = {"default": backend} compiler.compile_expr(expr) expr_hash = "5ebaf6a7a02d" roundtrip_expr = compiler.load_expr(expr_hash) @@ -91,9 +88,10 @@ def test_ibis_compiler_parquet_reader(t, build_dir): assert expr.execute().equals(roundtrip_expr.execute()) +# TODO: how to not use parquet reader or used deferred read +@pytest.mark.xfail def test_compiler_sql(build_dir): backend = ls.datafusion.connect() - backend.profile_name = "default" awards_players = backend.read_parquet( ls.config.options.pins.get_path("awards_players"), table_name="awards_players", @@ -101,7 +99,6 @@ def test_compiler_sql(build_dir): expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") compiler = BuildManager(build_dir) - compiler.profiles = {"default": backend} compiler.compile_expr(expr) expr_hash = "5ebaf6a7a02d" _roundtrip_expr = compiler.load_expr(expr_hash) @@ -126,11 +123,8 @@ def test_compiler_sql(build_dir): def test_ibis_compiler_expr_schema_ref(t, build_dir): t = ls.memtable({"a": [0, 1], "b": [0, 1]}) - backend = t._find_backend() - backend.profile_name = "default" expr = t.filter(t.a == 1).drop("b") compiler = BuildManager(build_dir) - compiler.profiles = {"default": backend} compiler.compile_expr(expr) expr_hash = 
dask.base.tokenize(expr)[:12] diff --git a/python/letsql/ibis_yaml/tests/test_letsql_ops.py b/python/letsql/ibis_yaml/tests/test_letsql_ops.py index 62a7ffbd..0c773e33 100644 --- a/python/letsql/ibis_yaml/tests/test_letsql_ops.py +++ b/python/letsql/ibis_yaml/tests/test_letsql_ops.py @@ -37,7 +37,7 @@ def prepare_duckdb_con(duckdb_path): def test_duckdb_database_table_roundtrip(prepare_duckdb_con, build_dir): con = prepare_duckdb_con - profiles = {"my_duckdb": con} + profiles = {con._profile.hash_name: con} table_expr = con.table("mytable") @@ -59,10 +59,9 @@ def test_duckdb_database_table_roundtrip(prepare_duckdb_con, build_dir): def test_memtable(prepare_duckdb_con, build_dir): table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) backend = table._find_backend() - backend.profile_name = "default-duckdb" expr = table.mutate(new_val=2 * ls._.val) - profiles = {"default-duckdb": backend} + profiles = {backend._profile.hash_name: backend} compiler = YamlExpressionTranslator(current_path=build_dir, profiles=profiles) @@ -77,21 +76,18 @@ def test_memtable(prepare_duckdb_con, build_dir): def test_into_backend(prepare_duckdb_con, build_dir): table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) backend = table._find_backend() - backend.profile_name = "default-duckdb" expr = table.mutate(new_val=2 * ls._.val) con2 = ls.connect() - con2.profile_name = "default-let" con3 = ls.connect() - con3.profile_name = "default-datafusion" expr = into_backend(expr, con2, "ls_mem").mutate(x=4 * ls._.val) expr = into_backend(expr, con3, "df_mem") profiles = { - "default-duckdb": backend, - "default-let": con2, - "default-datafusion": con3, + backend._profile.hash_name: backend, + con2._profile.hash_name: con2, + con3._profile.hash_name: con3, } compiler = YamlExpressionTranslator(current_path=build_dir, profiles=profiles) @@ -105,15 +101,15 @@ def test_into_backend(prepare_duckdb_con, build_dir): def test_memtable_cache(prepare_duckdb_con, 
build_dir): table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) backend = table._find_backend() - backend.profile_name = "default-duckdb" expr = table.mutate(new_val=2 * ls._.val).cache() backend1 = expr._find_backend() - backend1.profile_name = "default-let" - profiles = {"default-duckdb": backend, "default-let": backend1} + profiles = { + backend._profile.hash_name: backend, + backend1._profile.hash_name: backend1, + } compiler = YamlExpressionTranslator(profiles=profiles, current_path=build_dir) - compiler.profiles = profiles yaml_dict = compiler.to_yaml(expr) roundtrip_expr = compiler.from_yaml(yaml_dict) diff --git a/python/letsql/ibis_yaml/tests/test_sql.py b/python/letsql/ibis_yaml/tests/test_sql.py index 2b158d8e..c1593e19 100644 --- a/python/letsql/ibis_yaml/tests/test_sql.py +++ b/python/letsql/ibis_yaml/tests/test_sql.py @@ -1,5 +1,3 @@ -import pytest - import letsql as ls import letsql.vendor.ibis.expr.operations as ops from letsql.expr.relations import RemoteTable, into_backend @@ -22,23 +20,6 @@ def test_find_remote_tables_simple(): assert remote_tables[table_name]["engine"] == "duckdb" -def test_find_remote_tables_raises(): - db = ls.connect() - - awards_players = db.read_parquet( - ls.config.options.pins.get_path("awards_players"), - table_name="awards_players", - ) - - db2 = ls.datafusion.connect() - - remote_expr = into_backend(awards_players, db2) - with pytest.raises( - AttributeError, match="Backend does not have a valid 'profile_name' attribute." 
- ): - find_remote_tables(remote_expr.op()) - - def test_find_remote_tables_nested(): db1 = ls.duckdb.connect() db1.profile_name = "duckdb" @@ -112,16 +93,6 @@ def test_generate_sql_plans_simple(): assert all("sql" in q and "engine" in q for q in plans["queries"].values()) -def test_generate_sql_plans_raises(): - db = ls.duckdb.connect() - table = ls.memtable([(1, "a"), (2, "b")], columns=["id", "val"]) - expr = into_backend(table, db).filter(ls._.id > 1) - with pytest.raises( - AttributeError, match="Backend does not have a valid 'profile_name' attribute." - ): - generate_sql_plans(expr) - - def test_generate_sql_plans_complex_example(): pg = ls.postgres.connect_examples() pg.profile_name = "postgres" diff --git a/python/letsql/ibis_yaml/translate.py b/python/letsql/ibis_yaml/translate.py index bab90a10..a419fabe 100644 --- a/python/letsql/ibis_yaml/translate.py +++ b/python/letsql/ibis_yaml/translate.py @@ -288,7 +288,7 @@ def _unbound_table_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @translate_to_yaml.register(ops.DatabaseTable) def _database_table_to_yaml(op: ops.DatabaseTable, compiler: Any) -> dict: - profile_name = getattr(op.source, "profile_name", None) + profile_name = op.source._profile.hash_name schema_id = compiler.schema_registry.register_schema(op.schema) return freeze( @@ -325,13 +325,14 @@ def database_table_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: @translate_to_yaml.register(CachedNode) def _cached_node_to_yaml(op: CachedNode, compiler: any) -> dict: schema_id = compiler.schema_registry.register_schema(op.schema) + # source should be called profile_name return freeze( { "op": "CachedNode", "schema_ref": schema_id, "parent": translate_to_yaml(op.parent, compiler), - "source": getattr(op.source, "profile_name", None), + "source": op.source._profile.hash_name, "storage": translate_storage(op.storage, compiler), "values": dict(op.values), } @@ -393,27 +394,6 @@ def _memtable_to_yaml(op: ops.InMemoryTable, compiler: Any) -> 
dict: ) -@register_from_yaml_handler("InMemoryTable") -def _memtable_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: - if not hasattr(compiler, "definitions"): - raise ValueError("Compiler missing definitions with schemas") - - file_path = yaml_dict["file"] - schema_ref = yaml_dict["schema_ref"] - try: - schema_def = compiler.definitions["schemas"][schema_ref] - except KeyError: - raise ValueError(f"Schema {schema_ref} not found in definitions") - - arrow_table = pq.read_table(file_path) - df = arrow_table.to_pandas() - table_name = yaml_dict.get("table", "memtable") - - column_names = list(schema_def.keys()) - memtable_expr = ls.memtable(df, columns=column_names, name=table_name) - return memtable_expr - - @register_from_yaml_handler("InMemoryTable") def _memtable_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: file_path = yaml_dict["file"] @@ -428,10 +408,10 @@ def _memtable_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: @translate_to_yaml.register(RemoteTable) def _remotetable_to_yaml(op: RemoteTable, compiler: any) -> dict: - profile_name = getattr(op.source, "profile_name", None) + profile_name = op.source._profile.hash_name remote_expr_yaml = translate_to_yaml(op.remote_expr, compiler) schema_id = compiler.schema_registry.register_schema(op.schema) - + # TODO: change profile to profile_name return freeze( { "op": "RemoteTable", diff --git a/python/letsql/ibis_yaml/utils.py b/python/letsql/ibis_yaml/utils.py index 4d746c97..6985feaf 100644 --- a/python/letsql/ibis_yaml/utils.py +++ b/python/letsql/ibis_yaml/utils.py @@ -1,10 +1,12 @@ import base64 from collections.abc import Mapping, Sequence -from typing import Any, Dict +from typing import Any, Dict, Tuple import cloudpickle +import letsql.vendor.ibis.expr.operations as ops from letsql.common.caching import SourceStorage +from letsql.expr.relations import RemoteTable from letsql.vendor.ibis.common.collections import FrozenOrderedDict @@ -134,7 +136,7 @@ def translate_storage(storage, 
compiler: Any) -> Dict: if isinstance(storage, SourceStorage): return { "type": "SourceStorage", - "source": getattr(storage.source, "profile_name", None), + "source": storage.source._profile.hash_name, } else: raise NotImplementedError(f"Unknown storage type: {type(storage)}") @@ -152,3 +154,36 @@ def load_storage_from_yaml(storage_yaml: Dict, compiler: Any): return SourceStorage(source=source) else: raise NotImplementedError(f"Unknown storage type: {storage_yaml['type']}") + + +def find_remote_backends(op) -> Tuple: + remote_backends = () + seen = set() + + def traverse(node): + nonlocal remote_backends + if node is None or id(node) in seen: + return + + seen.add(id(node)) + + if isinstance(node, ops.Node) and isinstance(node, RemoteTable): + remote_expr = node.remote_expr + original_backend = remote_expr._find_backend() + remote_backends += (original_backend,) + + if isinstance(node, ops.Node): + for arg in node.args: + if isinstance(arg, ops.Node): + traverse(arg) + elif isinstance(arg, (list, tuple)): + for item in arg: + if isinstance(item, ops.Node): + traverse(item) + elif isinstance(arg, dict): + for v in arg.values(): + if isinstance(v, ops.Node): + traverse(v) + + traverse(op) + return remote_backends From 7772102e37adb18b7a478a995a9faaeb1c70a16c Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Tue, 18 Feb 2025 12:07:19 -0500 Subject: [PATCH 15/45] feat: add Read support for yaml roundtrip --- .../letsql/ibis_yaml/tests/test_letsql_ops.py | 26 ++++++++-- python/letsql/ibis_yaml/translate.py | 48 ++++++++++++++++++- 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/python/letsql/ibis_yaml/tests/test_letsql_ops.py b/python/letsql/ibis_yaml/tests/test_letsql_ops.py index 0c773e33..bff78103 100644 --- a/python/letsql/ibis_yaml/tests/test_letsql_ops.py +++ b/python/letsql/ibis_yaml/tests/test_letsql_ops.py @@ -1,6 +1,10 @@ import pytest import letsql as ls +from letsql import _ +from letsql.common.utils.defer_utils import ( + 
deferred_read_csv, +) from letsql.expr.relations import into_backend from letsql.ibis_yaml.compiler import YamlExpressionTranslator @@ -56,7 +60,7 @@ def test_duckdb_database_table_roundtrip(prepare_duckdb_con, build_dir): assert df_original.equals(df_roundtrip), "Roundtrip expression data differs!" -def test_memtable(prepare_duckdb_con, build_dir): +def test_memtable(build_dir): table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) backend = table._find_backend() expr = table.mutate(new_val=2 * ls._.val) @@ -73,7 +77,7 @@ def test_memtable(prepare_duckdb_con, build_dir): assert expr.execute().equals(roundtrip_expr.execute()) -def test_into_backend(prepare_duckdb_con, build_dir): +def test_into_backend(build_dir): table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) backend = table._find_backend() expr = table.mutate(new_val=2 * ls._.val) @@ -98,7 +102,7 @@ def test_into_backend(prepare_duckdb_con, build_dir): assert ls.execute(expr).equals(ls.execute(roundtrip_expr)) -def test_memtable_cache(prepare_duckdb_con, build_dir): +def test_memtable_cache(build_dir): table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) backend = table._find_backend() expr = table.mutate(new_val=2 * ls._.val).cache() @@ -115,3 +119,19 @@ def test_memtable_cache(prepare_duckdb_con, build_dir): roundtrip_expr = compiler.from_yaml(yaml_dict) assert ls.execute(expr).equals(ls.execute(roundtrip_expr)) + + +def test_deferred_read_csv(build_dir): + csv_name = "iris" + csv_path = ls.options.pins.get_path(csv_name) + pd_con = ls.pandas.connect() + expr = deferred_read_csv(con=pd_con, path=csv_path, table_name=csv_name).filter( + _.sepal_length > 6 + ) + + profiles = {pd_con._profile.hash_name: pd_con} + compiler = YamlExpressionTranslator(profiles=profiles, current_path=build_dir) + yaml_dict = compiler.to_yaml(expr) + roundtrip_expr = compiler.from_yaml(yaml_dict) + + assert 
ls.execute(expr).equals(ls.execute(roundtrip_expr)) diff --git a/python/letsql/ibis_yaml/translate.py b/python/letsql/ibis_yaml/translate.py index a419fabe..d3fd4915 100644 --- a/python/letsql/ibis_yaml/translate.py +++ b/python/letsql/ibis_yaml/translate.py @@ -14,7 +14,7 @@ import letsql.vendor.ibis.expr.operations.temporal as tm import letsql.vendor.ibis.expr.rules as rlz import letsql.vendor.ibis.expr.types as ir -from letsql.expr.relations import CachedNode, RemoteTable, into_backend +from letsql.expr.relations import CachedNode, Read, RemoteTable, into_backend from letsql.ibis_yaml.utils import ( deserialize_udf_function, freeze, @@ -443,6 +443,52 @@ def _remotetable_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: return remote_table_expr +@translate_to_yaml.register(Read) +def _read_to_yaml(op: Read, compiler: Any) -> dict: + schema_id = compiler.schema_registry.register_schema(op.schema) + profile_hash_name = ( + op.source._profile.hash_name if hasattr(op.source, "_profile") else None + ) + + return freeze( + { + "op": "Read", + "method_name": op.method_name, + "name": op.name, + "schema_ref": schema_id, + "profile": profile_hash_name, + "read_kwargs": dict(op.read_kwargs) if op.read_kwargs else {}, + } + ) + + +@register_from_yaml_handler("Read") +def _read_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: + if not hasattr(compiler, "definitions"): + raise ValueError("Compiler missing definitions with schemas") + + schema_ref = yaml_dict["schema_ref"] + schema_def = compiler.definitions["schemas"][schema_ref] + + schema = { + name: _type_from_yaml(dtype_yaml) for name, dtype_yaml in schema_def.items() + } + + profile_hash_name = yaml_dict.get("profile") + + source = compiler.profiles[profile_hash_name] + + read_op = Read( + method_name=yaml_dict["method_name"], + name=yaml_dict["name"], + schema=schema, + source=source, + read_kwargs=yaml_dict.get("read_kwargs", {}), + ) + + return read_op.to_expr() + + @translate_to_yaml.register(ops.Literal) 
def _literal_to_yaml(op: ops.Literal, compiler: Any) -> dict: value = _translate_literal_value(op.value, op.dtype) From 050d818aa0c7dff61fde991a8d79e7724543484d Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Tue, 18 Feb 2025 19:08:09 -0500 Subject: [PATCH 16/45] chore: unmark xfail tests for deferred reads - add `find_all_backends` method with Read Op support --- python/letsql/ibis_yaml/compiler.py | 6 ++- .../letsql/ibis_yaml/tests/test_compiler.py | 33 +++++++------- python/letsql/ibis_yaml/utils.py | 44 ++++++++++++------- 3 files changed, 46 insertions(+), 37 deletions(-) diff --git a/python/letsql/ibis_yaml/compiler.py b/python/letsql/ibis_yaml/compiler.py index ac8164e5..bf5ae089 100644 --- a/python/letsql/ibis_yaml/compiler.py +++ b/python/letsql/ibis_yaml/compiler.py @@ -12,7 +12,7 @@ translate_from_yaml, translate_to_yaml, ) -from letsql.ibis_yaml.utils import find_remote_backends, freeze +from letsql.ibis_yaml.utils import find_all_backends, freeze from letsql.vendor.ibis.backends import Profile from letsql.vendor.ibis.common.collections import FrozenOrderedDict @@ -139,12 +139,14 @@ def compile_expr(self, expr: ir.Expr) -> None: expr_hash = self.artifact_store.get_expr_hash(expr) current_path = self.artifact_store.get_build_path(expr_hash) - backends = (expr._find_backend(), *find_remote_backends(expr.op())) + backends = find_all_backends(expr.op()) profiles = { backend._profile.hash_name: backend._profile.as_dict() for backend in backends } + print(profiles) + translator = YamlExpressionTranslator( profiles=profiles, current_path=current_path ) diff --git a/python/letsql/ibis_yaml/tests/test_compiler.py b/python/letsql/ibis_yaml/tests/test_compiler.py index 58e001da..c7377469 100644 --- a/python/letsql/ibis_yaml/tests/test_compiler.py +++ b/python/letsql/ibis_yaml/tests/test_compiler.py @@ -2,10 +2,10 @@ import pathlib import dask -import pytest import yaml import letsql as ls +from letsql.common.utils.defer_utils import deferred_read_parquet from 
letsql.ibis_yaml.compiler import ArtifactStore, BuildManager from letsql.vendor.ibis.common.collections import FrozenOrderedDict @@ -71,28 +71,26 @@ def test_ibis_compiler(t, build_dir): assert expr.execute().equals(roundtrip_expr.execute()) -@pytest.mark.xfail -def test_ibis_compiler_parquet_reader(t, build_dir): - backend = ls.datafusion.connect() - awards_players = backend.read_parquet( - ls.config.options.pins.get_path("awards_players"), - table_name="awards_players", +def test_ibis_compiler_parquet_reader(build_dir): + backend = ls.duckdb.connect() + parquet_path = ls.config.options.pins.get_path("awards_players") + awards_players = deferred_read_parquet( + backend, parquet_path, table_name="award_players" ) expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") - compiler = BuildManager(build_dir) compiler.compile_expr(expr) - expr_hash = "5ebaf6a7a02d" + print(dask.base.tokenize(expr)[:12]) + expr_hash = "9a7d0b20d41a" roundtrip_expr = compiler.load_expr(expr_hash) assert expr.execute().equals(roundtrip_expr.execute()) -# TODO: how to not use parquet reader or used deferred read -@pytest.mark.xfail def test_compiler_sql(build_dir): backend = ls.datafusion.connect() - awards_players = backend.read_parquet( + awards_players = deferred_read_parquet( + backend, ls.config.options.pins.get_path("awards_players"), table_name="awards_players", ) @@ -100,7 +98,7 @@ def test_compiler_sql(build_dir): compiler = BuildManager(build_dir) compiler.compile_expr(expr) - expr_hash = "5ebaf6a7a02d" + expr_hash = "79d83e9c89ad" _roundtrip_expr = compiler.load_expr(expr_hash) assert os.path.exists(build_dir / expr_hash / "sql.yaml") @@ -109,13 +107,12 @@ def test_compiler_sql(build_dir): expected_result = ( "queries:\n" " main:\n" - " engine: datafusion\n" - " profile_name: default\n" + " engine: let\n" + f" profile_name: {expr._find_backend()._profile.hash_name}\n" ' sql: "SELECT\\n \\"t0\\".\\"playerID\\",\\n ' '\\"t0\\".\\"awardID\\",\\n 
\\"t0\\".\\"tie\\"\\\n' - ' ,\\n \\"t0\\".\\"notes\\"\\nFROM \\"awards_players\\" AS ' - '\\"t0\\"\\nWHERE\\n \\"t0\\".\\"\\\n' - " lgID\\\" = 'NL'\"\n" + ' ,\\n \\"t0\\".\\"notes\\"\\nFROM \\"awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\\"\\' + '\n \\ AS \\"t0\\"\\nWHERE\\n \\"t0\\".\\"lgID\\" = \'NL\'"\n' ) assert sql_text == expected_result diff --git a/python/letsql/ibis_yaml/utils.py b/python/letsql/ibis_yaml/utils.py index 6985feaf..2cf5e435 100644 --- a/python/letsql/ibis_yaml/utils.py +++ b/python/letsql/ibis_yaml/utils.py @@ -5,9 +5,12 @@ import cloudpickle import letsql.vendor.ibis.expr.operations as ops +import letsql.vendor.ibis.expr.types as ir from letsql.common.caching import SourceStorage -from letsql.expr.relations import RemoteTable +from letsql.expr.relations import CachedNode, Read +from letsql.vendor.ibis.backends import BaseBackend from letsql.vendor.ibis.common.collections import FrozenOrderedDict +from letsql.vendor.ibis.expr.types.relations import Table def serialize_udf_function(fn: callable) -> str: @@ -145,32 +148,38 @@ def translate_storage(storage, compiler: Any) -> Dict: def load_storage_from_yaml(storage_yaml: Dict, compiler: Any): if storage_yaml["type"] == "SourceStorage": source_profile_name = storage_yaml["source"] - try: - source = compiler.profiles[source_profile_name] - except KeyError: - raise ValueError( - f"Source profile {source_profile_name!r} not found in compiler.profiles" - ) + source = compiler.profiles[source_profile_name] return SourceStorage(source=source) else: raise NotImplementedError(f"Unknown storage type: {storage_yaml['type']}") -def find_remote_backends(op) -> Tuple: - remote_backends = () +def find_all_backends(expr: ir.Expr) -> Tuple[BaseBackend, ...]: + backends = set() seen = set() def traverse(node): - nonlocal remote_backends if node is None or id(node) in seen: return - seen.add(id(node)) - if isinstance(node, ops.Node) and isinstance(node, RemoteTable): - remote_expr = node.remote_expr - 
original_backend = remote_expr._find_backend() - remote_backends += (original_backend,) + if isinstance(node, Table): + traverse(node.op()) + return + + if isinstance(node, Read): + backend = node.source + if backend is not None: + backends.add(backend) + + elif isinstance(node, ops.DatabaseTable): + backends.add(node.source) + + elif isinstance(node, ops.SQLQueryResult): # caching_utils uses + backends.add(node.source) + + elif isinstance(node, CachedNode): + backends.add(node.source) if isinstance(node, ops.Node): for arg in node.args: @@ -185,5 +194,6 @@ def traverse(node): if isinstance(v, ops.Node): traverse(v) - traverse(op) - return remote_backends + traverse(expr) + + return tuple(backends) From d9762470198b459d6ac8595def92e6f96fc790ae Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Wed, 19 Feb 2025 09:29:31 -0500 Subject: [PATCH 17/45] wip --- python/letsql/ibis_yaml/sql.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/python/letsql/ibis_yaml/sql.py b/python/letsql/ibis_yaml/sql.py index 24a00874..597301ee 100644 --- a/python/letsql/ibis_yaml/sql.py +++ b/python/letsql/ibis_yaml/sql.py @@ -3,7 +3,7 @@ import letsql.vendor.ibis as ibis import letsql.vendor.ibis.expr.operations as ops import letsql.vendor.ibis.expr.types as ir -from letsql.expr.relations import RemoteTable +from letsql.expr.relations import Read, RemoteTable class QueryInfo(TypedDict): @@ -37,6 +37,17 @@ def traverse(node): "profile_name": profile_name, "sql": ibis.to_sql(remote_expr), } + if isinstance(node, Read): + backend = node.source + if backend is not None: + engine_name = backend.name + profile_name = backend._profile.hash_name + + remote_tables[node.name] = { + "engine": engine_name, + "profile_name": profile_name, + "sql": ibis.to_sql(node.make_unbound_dt().to_expr()), + } if isinstance(node, ops.Node): for arg in node.args: From b1368975cd79773fd07b43ae911a1306d0736ff0 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Wed, 19 Feb 2025 
12:50:13 -0500 Subject: [PATCH 18/45] feat: add Read op to sql serialization --- python/letsql/ibis_yaml/sql.py | 8 +++- .../letsql/ibis_yaml/tests/test_compiler.py | 15 +++++--- python/letsql/ibis_yaml/utils.py | 37 ++++++++++++++++++- 3 files changed, 51 insertions(+), 9 deletions(-) diff --git a/python/letsql/ibis_yaml/sql.py b/python/letsql/ibis_yaml/sql.py index 597301ee..8e31b5b4 100644 --- a/python/letsql/ibis_yaml/sql.py +++ b/python/letsql/ibis_yaml/sql.py @@ -4,6 +4,7 @@ import letsql.vendor.ibis.expr.operations as ops import letsql.vendor.ibis.expr.types as ir from letsql.expr.relations import Read, RemoteTable +from letsql.ibis_yaml.utils import find_relations class QueryInfo(TypedDict): @@ -35,6 +36,7 @@ def traverse(node): remote_tables[node.name] = { "engine": engine_name, "profile_name": profile_name, + "relations": find_relations(remote_expr), "sql": ibis.to_sql(remote_expr), } if isinstance(node, Read): @@ -42,10 +44,10 @@ def traverse(node): if backend is not None: engine_name = backend.name profile_name = backend._profile.hash_name - - remote_tables[node.name] = { + remote_tables[node.make_unbound_dt().name] = { "engine": engine_name, "profile_name": profile_name, + "relations": node.make_unbound_dt().name, "sql": ibis.to_sql(node.make_unbound_dt().to_expr()), } @@ -81,6 +83,7 @@ def generate_sql_plans(expr: ir.Expr) -> SQLPlans: "main": { "engine": engine_name, "profile_name": profile_name, + "relations": list(find_relations(expr)), "sql": main_sql.strip(), } } @@ -89,6 +92,7 @@ def generate_sql_plans(expr: ir.Expr) -> SQLPlans: for table_name, info in remote_tables.items(): plans["queries"][table_name] = { "engine": info["engine"], + "relations": info["relations"], "profile_name": info["profile_name"], "sql": info["sql"].strip(), } diff --git a/python/letsql/ibis_yaml/tests/test_compiler.py b/python/letsql/ibis_yaml/tests/test_compiler.py index c7377469..f67e38ea 100644 --- a/python/letsql/ibis_yaml/tests/test_compiler.py +++ 
b/python/letsql/ibis_yaml/tests/test_compiler.py @@ -109,12 +109,17 @@ def test_compiler_sql(build_dir): " main:\n" " engine: let\n" f" profile_name: {expr._find_backend()._profile.hash_name}\n" - ' sql: "SELECT\\n \\"t0\\".\\"playerID\\",\\n ' - '\\"t0\\".\\"awardID\\",\\n \\"t0\\".\\"tie\\"\\\n' - ' ,\\n \\"t0\\".\\"notes\\"\\nFROM \\"awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\\"\\' - '\n \\ AS \\"t0\\"\\nWHERE\\n \\"t0\\".\\"lgID\\" = \'NL\'"\n' + " relations:\n" + " - awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\n" + ' sql: "SELECT\\n \\"t0\\".\\"playerID\\",\\n \\"t0\\".\\"awardID\\",\\n \\"t0\\".\\"tie\\"' + '\\\n ,\\n \\"t0\\".\\"notes\\"\\nFROM \\"awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\\"' + '\\\n \\ AS \\"t0\\"\\nWHERE\\n \\"t0\\".\\"lgID\\" = \'NL\'"\n' + " awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f:\n" + " engine: datafusion\n" + " relations: awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\n" + " profile_name: a506210f56203e8f9b4a84ef73d95eaa\n" + ' sql: "SELECT\\n *\\nFROM \\"awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\\""\n' ) - assert sql_text == expected_result diff --git a/python/letsql/ibis_yaml/utils.py b/python/letsql/ibis_yaml/utils.py index 2cf5e435..2d1ffe5f 100644 --- a/python/letsql/ibis_yaml/utils.py +++ b/python/letsql/ibis_yaml/utils.py @@ -1,13 +1,13 @@ import base64 from collections.abc import Mapping, Sequence -from typing import Any, Dict, Tuple +from typing import Any, Dict, List, Tuple import cloudpickle import letsql.vendor.ibis.expr.operations as ops import letsql.vendor.ibis.expr.types as ir from letsql.common.caching import SourceStorage -from letsql.expr.relations import CachedNode, Read +from letsql.expr.relations import CachedNode, Read, RemoteTable from letsql.vendor.ibis.backends import BaseBackend from letsql.vendor.ibis.common.collections import FrozenOrderedDict from letsql.vendor.ibis.expr.types.relations import Table @@ -197,3 +197,36 @@ def traverse(node): traverse(expr) return 
tuple(backends) + + +def find_relations(expr: ir.Expr) -> List[str]: + relations = [] + seen = set() + + def traverse(node): + if node is None or id(node) in seen: + return + seen.add(id(node)) + + if isinstance(node, ops.Node): + if isinstance(node, RemoteTable): + relations.append(node.name) + elif isinstance(node, Read): + relations.append(node.make_unbound_dt().name) + elif isinstance(node, ops.DatabaseTable): + relations.append(node.name) + + for arg in node.args: + if isinstance(arg, ops.Node): + traverse(arg) + elif isinstance(arg, (list, tuple)): + for item in arg: + if isinstance(item, ops.Node): + traverse(item) + elif isinstance(arg, dict): + for v in arg.values(): + if isinstance(v, ops.Node): + traverse(v) + + traverse(expr.op()) + return list(dict.fromkeys(relations)) From f3614c88e4100ecb33d800ce9a354b4176647a70 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Wed, 19 Feb 2025 13:25:46 -0500 Subject: [PATCH 19/45] refactor: raise error when UDFs are not of proper input type --- python/letsql/ibis_yaml/tests/test_udf.py | 12 ++++++++++++ python/letsql/ibis_yaml/translate.py | 5 +++++ 2 files changed, 17 insertions(+) diff --git a/python/letsql/ibis_yaml/tests/test_udf.py b/python/letsql/ibis_yaml/tests/test_udf.py index 5429dd87..5ba2cc56 100644 --- a/python/letsql/ibis_yaml/tests/test_udf.py +++ b/python/letsql/ibis_yaml/tests/test_udf.py @@ -31,6 +31,18 @@ def add_one(x: int) -> int: assert orig_arg.dtype == rt_arg.dtype +def test_compiler_raises(compiler): + t = ibis.table({"a": "int64"}, name="t") + + @ibis.udf.scalar.python + def add_one(x: int) -> int: + pass + + expr = t.mutate(new=add_one(t.a)) + with pytest.raises(NotImplementedError): + compiler.to_yaml(expr) + + @pytest.mark.xfail( reason="UDFs do not have the same memory address when pickled/unpickled" ) diff --git a/python/letsql/ibis_yaml/translate.py b/python/letsql/ibis_yaml/translate.py index d3fd4915..a17d1541 100644 --- a/python/letsql/ibis_yaml/translate.py +++ 
b/python/letsql/ibis_yaml/translate.py @@ -1049,6 +1049,11 @@ def _searched_case_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @translate_to_yaml.register(ops.ScalarUDF) def _scalar_udf_to_yaml(op: ops.ScalarUDF, compiler: Any) -> dict: + print(dir(op)) + if getattr(op.__class__, "__input_type__", None) != ops.udf.InputType.BUILTIN: + raise NotImplementedError( + f"Translation of UDFs with input type {getattr(op.__class__, '__input_type__', None)} is not supported" + ) arg_names = [ name for name in dir(op) From 05d84881d005b57f270a36826b1156420b357899 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Wed, 19 Feb 2025 13:59:38 -0500 Subject: [PATCH 20/45] refactor: remove memtable parquet serialization --- .../letsql/ibis_yaml/tests/test_compiler.py | 2 + .../letsql/ibis_yaml/tests/test_letsql_ops.py | 12 ++++-- python/letsql/ibis_yaml/translate.py | 38 ------------------- 3 files changed, 10 insertions(+), 42 deletions(-) diff --git a/python/letsql/ibis_yaml/tests/test_compiler.py b/python/letsql/ibis_yaml/tests/test_compiler.py index f67e38ea..1f5e5cf5 100644 --- a/python/letsql/ibis_yaml/tests/test_compiler.py +++ b/python/letsql/ibis_yaml/tests/test_compiler.py @@ -2,6 +2,7 @@ import pathlib import dask +import pytest import yaml import letsql as ls @@ -59,6 +60,7 @@ def test_clean_frozen_dict_yaml(build_dir): assert expected_yaml == result +@pytest.mark.xfail(reason="MemTable is not serializable") def test_ibis_compiler(t, build_dir): t = ls.memtable({"a": [0, 1], "b": [0, 1]}) expr = t.filter(t.a == 1).drop("b") diff --git a/python/letsql/ibis_yaml/tests/test_letsql_ops.py b/python/letsql/ibis_yaml/tests/test_letsql_ops.py index bff78103..3ae918a8 100644 --- a/python/letsql/ibis_yaml/tests/test_letsql_ops.py +++ b/python/letsql/ibis_yaml/tests/test_letsql_ops.py @@ -4,6 +4,7 @@ from letsql import _ from letsql.common.utils.defer_utils import ( deferred_read_csv, + deferred_read_parquet, ) from letsql.expr.relations import into_backend from 
letsql.ibis_yaml.compiler import YamlExpressionTranslator @@ -60,6 +61,7 @@ def test_duckdb_database_table_roundtrip(prepare_duckdb_con, build_dir): assert df_original.equals(df_roundtrip), "Roundtrip expression data differs!" +@pytest.mark.xfail(reason="MemTable is not serializable") def test_memtable(build_dir): table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) backend = table._find_backend() @@ -78,14 +80,15 @@ def test_memtable(build_dir): def test_into_backend(build_dir): - table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) - backend = table._find_backend() - expr = table.mutate(new_val=2 * ls._.val) + parquet_path = ls.config.options.pins.get_path("awards_players") + backend = ls.duckdb.connect() + table = deferred_read_parquet(backend, parquet_path, table_name="award_players") + expr = table.mutate(new_id=2 * ls._.playerID) con2 = ls.connect() con3 = ls.connect() - expr = into_backend(expr, con2, "ls_mem").mutate(x=4 * ls._.val) + expr = into_backend(expr, con2, "ls_mem").mutate(x=4 * ls._.new_id) expr = into_backend(expr, con3, "df_mem") profiles = { @@ -102,6 +105,7 @@ def test_into_backend(build_dir): assert ls.execute(expr).equals(ls.execute(roundtrip_expr)) +@pytest.mark.xfail(reason="MemTable is not serializable") def test_memtable_cache(build_dir): table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) backend = table._find_backend() diff --git a/python/letsql/ibis_yaml/translate.py b/python/letsql/ibis_yaml/translate.py index a17d1541..72b3615b 100644 --- a/python/letsql/ibis_yaml/translate.py +++ b/python/letsql/ibis_yaml/translate.py @@ -5,9 +5,6 @@ import functools from typing import Any -import pyarrow.parquet as pq - -import letsql as ls import letsql.vendor.ibis as ibis import letsql.vendor.ibis.expr.datatypes as dt import letsql.vendor.ibis.expr.operations as ops @@ -371,41 +368,6 @@ def _cached_node_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: return 
op.to_expr() -@translate_to_yaml.register(ops.InMemoryTable) -def _memtable_to_yaml(op: ops.InMemoryTable, compiler: Any) -> dict: - if not hasattr(compiler, "current_path"): - raise ValueError( - "Compiler is missing the 'current_path' attribute for memtable serialization" - ) - - arrow_table = op.data.to_pyarrow(op.schema) - file_path = compiler.current_path / f"memtable_{id(op)}.parquet" - pq.write_table(arrow_table, str(file_path)) - # probably do not need to store schema - schema_id = compiler.schema_registry.register_schema(op.schema) - - return freeze( - { - "op": "InMemoryTable", - "table": op.name, - "schema_ref": schema_id, - "file": str(file_path), - } - ) - - -@register_from_yaml_handler("InMemoryTable") -def _memtable_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: - file_path = yaml_dict["file"] - arrow_table = pq.read_table(file_path) - df = arrow_table.to_pandas() - - table_name = yaml_dict.get("table", "memtable") - - memtable_expr = ls.memtable(df, columns=list(df.columns), name=table_name) - return memtable_expr - - @translate_to_yaml.register(RemoteTable) def _remotetable_to_yaml(op: RemoteTable, compiler: any) -> dict: profile_name = op.source._profile.hash_name From b2ac0604625292090ffc52367945ba3086537fd1 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Thu, 20 Feb 2025 07:46:56 -0500 Subject: [PATCH 21/45] feat: split sql in its own files --- examples/yaml_roundrip.py | 25 +++++++++++++++++++++++++ python/letsql/ibis_yaml/compiler.py | 27 ++++++++++++++++++++++++++- python/letsql/ibis_yaml/sql.py | 16 +++++++++++++++- python/letsql/ibis_yaml/translate.py | 10 ---------- python/letsql/ibis_yaml/utils.py | 4 ++++ 5 files changed, 70 insertions(+), 12 deletions(-) create mode 100644 examples/yaml_roundrip.py diff --git a/examples/yaml_roundrip.py b/examples/yaml_roundrip.py new file mode 100644 index 00000000..e4196b16 --- /dev/null +++ b/examples/yaml_roundrip.py @@ -0,0 +1,25 @@ +import letsql as ls +from 
letsql.common.utils.defer_utils import deferred_read_parquet +from letsql.expr.relations import into_backend +from letsql.ibis_yaml.compiler import BuildManager + + +pg = ls.postgres.connect_examples() +db = ls.duckdb.connect() + +batting = pg.table("batting") + +backend = ls.duckdb.connect() +awards_players = deferred_read_parquet( + backend, + ls.config.options.pins.get_path("awards_players"), + table_name="award_players", +) +left = batting.filter(batting.yearID == 2015) +right = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") +expr = left.join(into_backend(right, pg), ["playerID"], how="semi")[["yearID", "stint"]] + +build_manager = BuildManager("builds") +build_manager.compile_expr(expr) + +roundtrip_expr = build_manager.load_expr("c6a24bb85380") diff --git a/python/letsql/ibis_yaml/compiler.py b/python/letsql/ibis_yaml/compiler.py index bf5ae089..0a8b9d0b 100644 --- a/python/letsql/ibis_yaml/compiler.py +++ b/python/letsql/ibis_yaml/compiler.py @@ -135,6 +135,30 @@ def __init__(self, build_dir: pathlib.Path): self.artifact_store = ArtifactStore(build_dir) self.profiles = {} + def _write_sql_file(self, sql: str, expr_hash: str, query_name: str) -> str: + sql_hash = dask.base.tokenize(sql)[:12] + filename = f"{sql_hash}.sql" + sql_path = self.artifact_store.get_build_path(expr_hash) / filename + sql_path.write_text(sql) + return filename + + def _process_sql_plans( + self, sql_plans: Dict[str, Any], expr_hash: str + ) -> Dict[str, Any]: + updated_plans = {"queries": {}} + + for query_name, query_info in sql_plans["queries"].items(): + sql_filename = self._write_sql_file( + query_info["sql"], expr_hash, query_name + ) + + updated_query_info = query_info.copy() + updated_query_info["sql_file"] = sql_filename + updated_query_info.pop("sql") + updated_plans["queries"][query_name] = updated_query_info + + return updated_plans + def compile_expr(self, expr: ir.Expr) -> None: expr_hash = self.artifact_store.get_expr_hash(expr) current_path = 
self.artifact_store.get_build_path(expr_hash) @@ -157,7 +181,8 @@ def compile_expr(self, expr: ir.Expr) -> None: self.artifact_store.save_yaml(profiles, expr_hash, "profiles.yaml") sql_plans = generate_sql_plans(expr) - self.artifact_store.save_yaml(sql_plans, expr_hash, "sql.yaml") + updated_sql_plans = self._process_sql_plans(sql_plans, expr_hash) + self.artifact_store.save_yaml(updated_sql_plans, expr_hash, "sql.yaml") def load_expr(self, expr_hash: str) -> ir.Expr: build_path = self.artifact_store.get_build_path(expr_hash) diff --git a/python/letsql/ibis_yaml/sql.py b/python/letsql/ibis_yaml/sql.py index 8e31b5b4..5dcb0ef5 100644 --- a/python/letsql/ibis_yaml/sql.py +++ b/python/letsql/ibis_yaml/sql.py @@ -17,6 +17,16 @@ class SQLPlans(TypedDict): queries: Dict[str, QueryInfo] +def get_read_options(read_instance): + read_kwargs_list = [{k: v} for k, v in read_instance.read_kwargs] + + return { + "method_name": read_instance.method_name, + "name": read_instance.name, + "read_kwargs": read_kwargs_list, + } + + def find_remote_tables(op) -> Dict[str, Dict[str, Any]]: remote_tables = {} seen = set() @@ -38,6 +48,7 @@ def traverse(node): "profile_name": profile_name, "relations": find_relations(remote_expr), "sql": ibis.to_sql(remote_expr), + "options": {}, } if isinstance(node, Read): backend = node.source @@ -47,8 +58,9 @@ def traverse(node): remote_tables[node.make_unbound_dt().name] = { "engine": engine_name, "profile_name": profile_name, - "relations": node.make_unbound_dt().name, + "relations": [node.make_unbound_dt().name], "sql": ibis.to_sql(node.make_unbound_dt().to_expr()), + "options": get_read_options(node), } if isinstance(node, ops.Node): @@ -85,6 +97,7 @@ def generate_sql_plans(expr: ir.Expr) -> SQLPlans: "profile_name": profile_name, "relations": list(find_relations(expr)), "sql": main_sql.strip(), + "options": {}, } } } @@ -95,6 +108,7 @@ def generate_sql_plans(expr: ir.Expr) -> SQLPlans: "relations": info["relations"], "profile_name": 
info["profile_name"], "sql": info["sql"].strip(), + "options": info["options"], } return plans diff --git a/python/letsql/ibis_yaml/translate.py b/python/letsql/ibis_yaml/translate.py index 72b3615b..a252e517 100644 --- a/python/letsql/ibis_yaml/translate.py +++ b/python/letsql/ibis_yaml/translate.py @@ -312,10 +312,6 @@ def database_table_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: con = compiler.profiles[profile_name] except KeyError: raise ValueError(f"Profile {profile_name!r} not found in compiler.profiles") - - if not hasattr(compiler, "definitions"): - raise ValueError("Compiler missing definitions with schemas") - return con.table(table_name) @@ -640,9 +636,6 @@ def _aggregate_to_yaml(op: ops.Aggregate, compiler: Any) -> dict: @register_from_yaml_handler("Aggregate") def _aggregate_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - if not hasattr(compiler, "definitions"): - raise ValueError("Compiler missing definitions with schemas") - parent = translate_from_yaml(yaml_dict["parent"], compiler) groups = tuple( translate_from_yaml(group, compiler) for group in yaml_dict.get("by", []) @@ -835,9 +828,6 @@ def _field_to_yaml(op: ops.Field, compiler: Any) -> dict: @register_from_yaml_handler("Field") def field_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - if not hasattr(compiler, "definitions"): - raise ValueError("Compiler missing definitions with schemas") - relation = translate_from_yaml(yaml_dict["relation"], compiler) target_name = yaml_dict["name"] source_name = yaml_dict.get("original_name", target_name) diff --git a/python/letsql/ibis_yaml/utils.py b/python/letsql/ibis_yaml/utils.py index 2d1ffe5f..2b856ff5 100644 --- a/python/letsql/ibis_yaml/utils.py +++ b/python/letsql/ibis_yaml/utils.py @@ -171,6 +171,10 @@ def traverse(node): backend = node.source if backend is not None: backends.add(backend) + elif isinstance(node, RemoteTable): + # this needs to habdle when a RemoteTable has Read op since the backend for the op is + # 
not the same as _find_backend() + backends.add(*find_all_backends(node.remote_expr)) elif isinstance(node, ops.DatabaseTable): backends.add(node.source) From b4cfe5bb12a53341aeb612859a6c8eceb461358c Mon Sep 17 00:00:00 2001 From: dlovell Date: Wed, 19 Feb 2025 19:43:11 -0500 Subject: [PATCH 22/45] feat: add walk_nodes, find_all_sources to graph_utils --- python/letsql/common/utils/graph_utils.py | 52 +++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 python/letsql/common/utils/graph_utils.py diff --git a/python/letsql/common/utils/graph_utils.py b/python/letsql/common/utils/graph_utils.py new file mode 100644 index 00000000..6908c24d --- /dev/null +++ b/python/letsql/common/utils/graph_utils.py @@ -0,0 +1,52 @@ +import xorq.expr.relations as rel + + +def walk_nodes(node_types, expr): + def process_node(op): + match op: + case rel.RemoteTable(): + yield op + yield from walk_nodes( + node_types, + op.remote_expr, + ) + case rel.CachedNode(): + yield op + yield from walk_nodes( + node_types, + op.parent, + ) + case _: + yield from op.find(node_types) + + def inner(rest, seen): + if not rest: + return seen + op = rest.pop() + seen.add(op) + new = process_node(op) + rest.update(set(new).difference(seen)) + return inner(rest, seen) + + rest = process_node(expr.op()) + return inner(set(rest), set()) + + +def find_all_sources(expr): + import xorq.vendor.ibis.expr.operations as ops + + node_types = ( + ops.DatabaseTable, + ops.SQLQueryResult, + rel.CachedNode, + rel.Read, + rel.RemoteTable, + # ExprScalarUDF has an expr we need to get to + # FlightOperator has a dynamically generated connection: it should be passed a Profile instead + ) + nodes = walk_nodes(node_types, expr) + sources = tuple( + source + for (source, _) in set((node.source, node.source._profile) for node in nodes) + ) + return sources From a2f3dcb660dc03b8d2c211ed0a00ab4e4421f979 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Thu, 20 Feb 2025 11:26:14 -0500 Subject: [PATCH 
23/45] wip: handle multiple Reads in RemoteTable --- examples/yaml_roundrip.py | 8 +++++--- python/letsql/ibis_yaml/compiler.py | 1 + python/letsql/ibis_yaml/sql.py | 6 ++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/examples/yaml_roundrip.py b/examples/yaml_roundrip.py index e4196b16..e83ee06a 100644 --- a/examples/yaml_roundrip.py +++ b/examples/yaml_roundrip.py @@ -17,9 +17,11 @@ ) left = batting.filter(batting.yearID == 2015) right = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") -expr = left.join(into_backend(right, pg), ["playerID"], how="semi")[["yearID", "stint"]] +expr = left.join( + into_backend(right, pg, "pg-filtered-table"), ["playerID"], how="semi" +)[["yearID", "stint"]] build_manager = BuildManager("builds") -build_manager.compile_expr(expr) +expr_hash = build_manager.compile_expr(expr) -roundtrip_expr = build_manager.load_expr("c6a24bb85380") +roundtrip_expr = build_manager.load_expr(expr_hash) diff --git a/python/letsql/ibis_yaml/compiler.py b/python/letsql/ibis_yaml/compiler.py index 0a8b9d0b..92d20ef0 100644 --- a/python/letsql/ibis_yaml/compiler.py +++ b/python/letsql/ibis_yaml/compiler.py @@ -183,6 +183,7 @@ def compile_expr(self, expr: ir.Expr) -> None: sql_plans = generate_sql_plans(expr) updated_sql_plans = self._process_sql_plans(sql_plans, expr_hash) self.artifact_store.save_yaml(updated_sql_plans, expr_hash, "sql.yaml") + return expr_hash def load_expr(self, expr_hash: str) -> ir.Expr: build_path = self.artifact_store.get_build_path(expr_hash) diff --git a/python/letsql/ibis_yaml/sql.py b/python/letsql/ibis_yaml/sql.py index 5dcb0ef5..b4a18b7a 100644 --- a/python/letsql/ibis_yaml/sql.py +++ b/python/letsql/ibis_yaml/sql.py @@ -4,7 +4,7 @@ import letsql.vendor.ibis.expr.operations as ops import letsql.vendor.ibis.expr.types as ir from letsql.expr.relations import Read, RemoteTable -from letsql.ibis_yaml.utils import find_relations +from letsql.ibis_yaml.utils import find_all_backends, 
find_relations class QueryInfo(TypedDict): @@ -39,7 +39,9 @@ def traverse(node): if isinstance(node, ops.Node) and isinstance(node, RemoteTable): remote_expr = node.remote_expr - original_backend = remote_expr._find_backend() + original_backend = find_all_backends(remote_expr)[ + 0 + ] # this was _find_backend before engine_name = original_backend.name profile_name = original_backend._profile.hash_name From 36e5d281329cd9d29eecd7e50814fbfdb4d1b8ca Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Thu, 20 Feb 2025 12:47:13 -0500 Subject: [PATCH 24/45] chore: rebase main with xorq name change --- examples/yaml_roundrip.py | 16 +++---- python/{letsql => xorq}/ibis_yaml/__init__.py | 0 python/{letsql => xorq}/ibis_yaml/compiler.py | 12 +++--- python/{letsql => xorq}/ibis_yaml/sql.py | 10 ++--- .../ibis_yaml/tests/__init__.py | 0 .../ibis_yaml/tests/conftest.py | 10 ++--- .../ibis_yaml/tests/test_arithmetic.py | 2 +- .../ibis_yaml/tests/test_basic.py | 2 +- .../ibis_yaml/tests/test_compiler.py | 20 ++++----- .../ibis_yaml/tests/test_join_chain.py | 2 +- .../ibis_yaml/tests/test_letsql_ops.py | 42 +++++++++---------- .../tests/test_operations_boolean.py | 2 +- .../ibis_yaml/tests/test_operations_cast.py | 2 +- .../tests/test_operations_datetime.py | 6 +-- .../ibis_yaml/tests/test_relations.py | 2 +- .../ibis_yaml/tests/test_selection.py | 2 +- .../ibis_yaml/tests/test_sql.py | 38 ++++++++--------- .../ibis_yaml/tests/test_string_ops.py | 2 +- .../ibis_yaml/tests/test_subquery.py | 4 +- .../ibis_yaml/tests/test_tpch.py | 0 .../ibis_yaml/tests/test_udf.py | 8 ++-- .../ibis_yaml/tests/test_window_functions.py | 2 +- .../{letsql => xorq}/ibis_yaml/translate.py | 18 ++++---- python/{letsql => xorq}/ibis_yaml/utils.py | 14 +++---- python/xorq/vendor/ibis/backends/__init__.py | 10 ----- requirements-dev.txt | 3 +- uv.lock | 1 - 27 files changed, 110 insertions(+), 120 deletions(-) rename python/{letsql => xorq}/ibis_yaml/__init__.py (100%) rename python/{letsql => 
xorq}/ibis_yaml/compiler.py (95%) rename python/{letsql => xorq}/ibis_yaml/sql.py (93%) rename python/{letsql => xorq}/ibis_yaml/tests/__init__.py (100%) rename python/{letsql => xorq}/ibis_yaml/tests/conftest.py (98%) rename python/{letsql => xorq}/ibis_yaml/tests/test_arithmetic.py (98%) rename python/{letsql => xorq}/ibis_yaml/tests/test_basic.py (99%) rename python/{letsql => xorq}/ibis_yaml/tests/test_compiler.py (88%) rename python/{letsql => xorq}/ibis_yaml/tests/test_join_chain.py (98%) rename python/{letsql => xorq}/ibis_yaml/tests/test_letsql_ops.py (75%) rename python/{letsql => xorq}/ibis_yaml/tests/test_operations_boolean.py (99%) rename python/{letsql => xorq}/ibis_yaml/tests/test_operations_cast.py (98%) rename python/{letsql => xorq}/ibis_yaml/tests/test_operations_datetime.py (96%) rename python/{letsql => xorq}/ibis_yaml/tests/test_relations.py (98%) rename python/{letsql => xorq}/ibis_yaml/tests/test_selection.py (91%) rename python/{letsql => xorq}/ibis_yaml/tests/test_sql.py (82%) rename python/{letsql => xorq}/ibis_yaml/tests/test_string_ops.py (97%) rename python/{letsql => xorq}/ibis_yaml/tests/test_subquery.py (94%) rename python/{letsql => xorq}/ibis_yaml/tests/test_tpch.py (100%) rename python/{letsql => xorq}/ibis_yaml/tests/test_udf.py (93%) rename python/{letsql => xorq}/ibis_yaml/tests/test_window_functions.py (98%) rename python/{letsql => xorq}/ibis_yaml/translate.py (98%) rename python/{letsql => xorq}/ibis_yaml/utils.py (94%) diff --git a/examples/yaml_roundrip.py b/examples/yaml_roundrip.py index e83ee06a..c72a65fe 100644 --- a/examples/yaml_roundrip.py +++ b/examples/yaml_roundrip.py @@ -1,18 +1,18 @@ -import letsql as ls -from letsql.common.utils.defer_utils import deferred_read_parquet -from letsql.expr.relations import into_backend -from letsql.ibis_yaml.compiler import BuildManager +import xorq as xo +from xorq.common.utils.defer_utils import deferred_read_parquet +from xorq.expr.relations import into_backend +from 
xorq.ibis_yaml.compiler import BuildManager -pg = ls.postgres.connect_examples() -db = ls.duckdb.connect() +pg = xo.postgres.connect_examples() +db = xo.duckdb.connect() batting = pg.table("batting") -backend = ls.duckdb.connect() +backend = xo.duckdb.connect() awards_players = deferred_read_parquet( backend, - ls.config.options.pins.get_path("awards_players"), + xo.config.options.pins.get_path("awards_players"), table_name="award_players", ) left = batting.filter(batting.yearID == 2015) diff --git a/python/letsql/ibis_yaml/__init__.py b/python/xorq/ibis_yaml/__init__.py similarity index 100% rename from python/letsql/ibis_yaml/__init__.py rename to python/xorq/ibis_yaml/__init__.py diff --git a/python/letsql/ibis_yaml/compiler.py b/python/xorq/ibis_yaml/compiler.py similarity index 95% rename from python/letsql/ibis_yaml/compiler.py rename to python/xorq/ibis_yaml/compiler.py index 92d20ef0..16cb2671 100644 --- a/python/letsql/ibis_yaml/compiler.py +++ b/python/xorq/ibis_yaml/compiler.py @@ -5,16 +5,16 @@ import dask import yaml -import letsql.vendor.ibis.expr.types as ir -from letsql.ibis_yaml.sql import generate_sql_plans -from letsql.ibis_yaml.translate import ( +import xorq.vendor.ibis.expr.types as ir +from xorq.ibis_yaml.sql import generate_sql_plans +from xorq.ibis_yaml.translate import ( SchemaRegistry, translate_from_yaml, translate_to_yaml, ) -from letsql.ibis_yaml.utils import find_all_backends, freeze -from letsql.vendor.ibis.backends import Profile -from letsql.vendor.ibis.common.collections import FrozenOrderedDict +from xorq.ibis_yaml.utils import find_all_backends, freeze +from xorq.vendor.ibis.backends import Profile +from xorq.vendor.ibis.common.collections import FrozenOrderedDict # is this the right way to handle this? 
or the right place diff --git a/python/letsql/ibis_yaml/sql.py b/python/xorq/ibis_yaml/sql.py similarity index 93% rename from python/letsql/ibis_yaml/sql.py rename to python/xorq/ibis_yaml/sql.py index b4a18b7a..0473d738 100644 --- a/python/letsql/ibis_yaml/sql.py +++ b/python/xorq/ibis_yaml/sql.py @@ -1,10 +1,10 @@ from typing import Any, Dict, TypedDict -import letsql.vendor.ibis as ibis -import letsql.vendor.ibis.expr.operations as ops -import letsql.vendor.ibis.expr.types as ir -from letsql.expr.relations import Read, RemoteTable -from letsql.ibis_yaml.utils import find_all_backends, find_relations +import xorq.vendor.ibis as ibis +import xorq.vendor.ibis.expr.operations as ops +import xorq.vendor.ibis.expr.types as ir +from xorq.expr.relations import Read, RemoteTable +from xorq.ibis_yaml.utils import find_all_backends, find_relations class QueryInfo(TypedDict): diff --git a/python/letsql/ibis_yaml/tests/__init__.py b/python/xorq/ibis_yaml/tests/__init__.py similarity index 100% rename from python/letsql/ibis_yaml/tests/__init__.py rename to python/xorq/ibis_yaml/tests/__init__.py diff --git a/python/letsql/ibis_yaml/tests/conftest.py b/python/xorq/ibis_yaml/tests/conftest.py similarity index 98% rename from python/letsql/ibis_yaml/tests/conftest.py rename to python/xorq/ibis_yaml/tests/conftest.py index 384380a5..057e4abf 100644 --- a/python/letsql/ibis_yaml/tests/conftest.py +++ b/python/xorq/ibis_yaml/tests/conftest.py @@ -2,8 +2,8 @@ import pytest -import letsql.vendor.ibis as ibis -import letsql.vendor.ibis.expr.datatypes as dt +import xorq.vendor.ibis as ibis +import xorq.vendor.ibis.expr.datatypes as dt # Fixtures from: https://github.com/ibis-project/ibis-substrait/blob/main/ibis_substrait/tests/compiler/test_tpch.py @@ -250,8 +250,8 @@ def tpc_h03(customer, orders, lineitem): @pytest.fixture def tpc_h04(orders, lineitem): - from letsql.vendor.ibis import _ - from letsql.vendor.ibis.expr.operations import ExistsSubquery + from xorq.vendor.ibis import 
_ + from xorq.vendor.ibis.expr.operations import ExistsSubquery lineitem_filtered = lineitem.filter( [ @@ -788,6 +788,6 @@ def build_dir(tmp_path_factory): @pytest.fixture def compiler(build_dir): - from letsql.ibis_yaml.compiler import YamlExpressionTranslator + from xorq.ibis_yaml.compiler import YamlExpressionTranslator return YamlExpressionTranslator() diff --git a/python/letsql/ibis_yaml/tests/test_arithmetic.py b/python/xorq/ibis_yaml/tests/test_arithmetic.py similarity index 98% rename from python/letsql/ibis_yaml/tests/test_arithmetic.py rename to python/xorq/ibis_yaml/tests/test_arithmetic.py index 37d0b4f7..6bbb7acb 100644 --- a/python/letsql/ibis_yaml/tests/test_arithmetic.py +++ b/python/xorq/ibis_yaml/tests/test_arithmetic.py @@ -1,4 +1,4 @@ -import letsql.vendor.ibis as ibis +import xorq.vendor.ibis as ibis def test_add(compiler): diff --git a/python/letsql/ibis_yaml/tests/test_basic.py b/python/xorq/ibis_yaml/tests/test_basic.py similarity index 99% rename from python/letsql/ibis_yaml/tests/test_basic.py rename to python/xorq/ibis_yaml/tests/test_basic.py index ce21f768..263d4cd2 100644 --- a/python/letsql/ibis_yaml/tests/test_basic.py +++ b/python/xorq/ibis_yaml/tests/test_basic.py @@ -1,7 +1,7 @@ import datetime import decimal -import letsql.vendor.ibis as ibis +import xorq.vendor.ibis as ibis def test_unbound_table(t, compiler): diff --git a/python/letsql/ibis_yaml/tests/test_compiler.py b/python/xorq/ibis_yaml/tests/test_compiler.py similarity index 88% rename from python/letsql/ibis_yaml/tests/test_compiler.py rename to python/xorq/ibis_yaml/tests/test_compiler.py index 1f5e5cf5..54843553 100644 --- a/python/letsql/ibis_yaml/tests/test_compiler.py +++ b/python/xorq/ibis_yaml/tests/test_compiler.py @@ -5,10 +5,10 @@ import pytest import yaml -import letsql as ls -from letsql.common.utils.defer_utils import deferred_read_parquet -from letsql.ibis_yaml.compiler import ArtifactStore, BuildManager -from letsql.vendor.ibis.common.collections import 
FrozenOrderedDict +import xorq as xo +from xorq.common.utils.defer_utils import deferred_read_parquet +from xorq.ibis_yaml.compiler import ArtifactStore, BuildManager +from xorq.vendor.ibis.common.collections import FrozenOrderedDict def test_build_manager_expr_hash(t, build_dir): @@ -62,7 +62,7 @@ def test_clean_frozen_dict_yaml(build_dir): @pytest.mark.xfail(reason="MemTable is not serializable") def test_ibis_compiler(t, build_dir): - t = ls.memtable({"a": [0, 1], "b": [0, 1]}) + t = xo.memtable({"a": [0, 1], "b": [0, 1]}) expr = t.filter(t.a == 1).drop("b") compiler = BuildManager(build_dir) compiler.compile_expr(expr) @@ -74,8 +74,8 @@ def test_ibis_compiler(t, build_dir): def test_ibis_compiler_parquet_reader(build_dir): - backend = ls.duckdb.connect() - parquet_path = ls.config.options.pins.get_path("awards_players") + backend = xo.duckdb.connect() + parquet_path = xo.config.options.pins.get_path("awards_players") awards_players = deferred_read_parquet( backend, parquet_path, table_name="award_players" ) @@ -90,10 +90,10 @@ def test_ibis_compiler_parquet_reader(build_dir): def test_compiler_sql(build_dir): - backend = ls.datafusion.connect() + backend = xo.datafusion.connect() awards_players = deferred_read_parquet( backend, - ls.config.options.pins.get_path("awards_players"), + xo.config.options.pins.get_path("awards_players"), table_name="awards_players", ) expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") @@ -126,7 +126,7 @@ def test_compiler_sql(build_dir): def test_ibis_compiler_expr_schema_ref(t, build_dir): - t = ls.memtable({"a": [0, 1], "b": [0, 1]}) + t = xo.memtable({"a": [0, 1], "b": [0, 1]}) expr = t.filter(t.a == 1).drop("b") compiler = BuildManager(build_dir) compiler.compile_expr(expr) diff --git a/python/letsql/ibis_yaml/tests/test_join_chain.py b/python/xorq/ibis_yaml/tests/test_join_chain.py similarity index 98% rename from python/letsql/ibis_yaml/tests/test_join_chain.py rename to 
python/xorq/ibis_yaml/tests/test_join_chain.py index 9da441c9..8c1e77f6 100644 --- a/python/letsql/ibis_yaml/tests/test_join_chain.py +++ b/python/xorq/ibis_yaml/tests/test_join_chain.py @@ -1,6 +1,6 @@ import pytest -import letsql.vendor.ibis as ibis +import xorq.vendor.ibis as ibis @pytest.fixture diff --git a/python/letsql/ibis_yaml/tests/test_letsql_ops.py b/python/xorq/ibis_yaml/tests/test_letsql_ops.py similarity index 75% rename from python/letsql/ibis_yaml/tests/test_letsql_ops.py rename to python/xorq/ibis_yaml/tests/test_letsql_ops.py index 3ae918a8..9ff50f71 100644 --- a/python/letsql/ibis_yaml/tests/test_letsql_ops.py +++ b/python/xorq/ibis_yaml/tests/test_letsql_ops.py @@ -1,13 +1,13 @@ import pytest -import letsql as ls -from letsql import _ -from letsql.common.utils.defer_utils import ( +import xorq as xo +from xorq import _ +from xorq.common.utils.defer_utils import ( deferred_read_csv, deferred_read_parquet, ) -from letsql.expr.relations import into_backend -from letsql.ibis_yaml.compiler import YamlExpressionTranslator +from xorq.expr.relations import into_backend +from xorq.ibis_yaml.compiler import YamlExpressionTranslator @pytest.fixture(scope="session") @@ -18,7 +18,7 @@ def duckdb_path(tmp_path_factory): @pytest.fixture(scope="session") def prepare_duckdb_con(duckdb_path): - con = ls.duckdb.connect(duckdb_path) + con = xo.duckdb.connect(duckdb_path) con.profile_name = "my_duckdb" # patch con.raw_sql( @@ -63,9 +63,9 @@ def test_duckdb_database_table_roundtrip(prepare_duckdb_con, build_dir): @pytest.mark.xfail(reason="MemTable is not serializable") def test_memtable(build_dir): - table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) + table = xo.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) backend = table._find_backend() - expr = table.mutate(new_val=2 * ls._.val) + expr = table.mutate(new_val=2 * xo._.val) profiles = {backend._profile.hash_name: backend} @@ -80,15 +80,15 @@ def 
test_memtable(build_dir): def test_into_backend(build_dir): - parquet_path = ls.config.options.pins.get_path("awards_players") - backend = ls.duckdb.connect() + parquet_path = xo.config.options.pins.get_path("awards_players") + backend = xo.duckdb.connect() table = deferred_read_parquet(backend, parquet_path, table_name="award_players") - expr = table.mutate(new_id=2 * ls._.playerID) + expr = table.mutate(new_id=2 * xo._.playerID) - con2 = ls.connect() - con3 = ls.connect() + con2 = xo.connect() + con3 = xo.connect() - expr = into_backend(expr, con2, "ls_mem").mutate(x=4 * ls._.new_id) + expr = into_backend(expr, con2, "ls_mem").mutate(x=4 * xo._.new_id) expr = into_backend(expr, con3, "df_mem") profiles = { @@ -102,14 +102,14 @@ def test_into_backend(build_dir): yaml_dict = compiler.to_yaml(expr) roundtrip_expr = compiler.from_yaml(yaml_dict) - assert ls.execute(expr).equals(ls.execute(roundtrip_expr)) + assert xo.execute(expr).equals(xo.execute(roundtrip_expr)) @pytest.mark.xfail(reason="MemTable is not serializable") def test_memtable_cache(build_dir): - table = ls.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) + table = xo.memtable([(i, "val") for i in range(10)], columns=["key1", "val"]) backend = table._find_backend() - expr = table.mutate(new_val=2 * ls._.val).cache() + expr = table.mutate(new_val=2 * xo._.val).cache() backend1 = expr._find_backend() profiles = { @@ -122,13 +122,13 @@ def test_memtable_cache(build_dir): yaml_dict = compiler.to_yaml(expr) roundtrip_expr = compiler.from_yaml(yaml_dict) - assert ls.execute(expr).equals(ls.execute(roundtrip_expr)) + assert xo.execute(expr).equals(xo.execute(roundtrip_expr)) def test_deferred_read_csv(build_dir): csv_name = "iris" - csv_path = ls.options.pins.get_path(csv_name) - pd_con = ls.pandas.connect() + csv_path = xo.options.pins.get_path(csv_name) + pd_con = xo.pandas.connect() expr = deferred_read_csv(con=pd_con, path=csv_path, table_name=csv_name).filter( _.sepal_length > 6 ) @@ 
-138,4 +138,4 @@ def test_deferred_read_csv(build_dir): yaml_dict = compiler.to_yaml(expr) roundtrip_expr = compiler.from_yaml(yaml_dict) - assert ls.execute(expr).equals(ls.execute(roundtrip_expr)) + assert xo.execute(expr).equals(xo.execute(roundtrip_expr)) diff --git a/python/letsql/ibis_yaml/tests/test_operations_boolean.py b/python/xorq/ibis_yaml/tests/test_operations_boolean.py similarity index 99% rename from python/letsql/ibis_yaml/tests/test_operations_boolean.py rename to python/xorq/ibis_yaml/tests/test_operations_boolean.py index 637131e5..4e991608 100644 --- a/python/letsql/ibis_yaml/tests/test_operations_boolean.py +++ b/python/xorq/ibis_yaml/tests/test_operations_boolean.py @@ -1,4 +1,4 @@ -import letsql.vendor.ibis as ibis +import xorq.vendor.ibis as ibis def test_equals(compiler): diff --git a/python/letsql/ibis_yaml/tests/test_operations_cast.py b/python/xorq/ibis_yaml/tests/test_operations_cast.py similarity index 98% rename from python/letsql/ibis_yaml/tests/test_operations_cast.py rename to python/xorq/ibis_yaml/tests/test_operations_cast.py index 03eb6219..4b032896 100644 --- a/python/letsql/ibis_yaml/tests/test_operations_cast.py +++ b/python/xorq/ibis_yaml/tests/test_operations_cast.py @@ -1,4 +1,4 @@ -import letsql.vendor.ibis as ibis +import xorq.vendor.ibis as ibis def test_explicit_cast(compiler): diff --git a/python/letsql/ibis_yaml/tests/test_operations_datetime.py b/python/xorq/ibis_yaml/tests/test_operations_datetime.py similarity index 96% rename from python/letsql/ibis_yaml/tests/test_operations_datetime.py rename to python/xorq/ibis_yaml/tests/test_operations_datetime.py index e1339252..8ca501c4 100644 --- a/python/letsql/ibis_yaml/tests/test_operations_datetime.py +++ b/python/xorq/ibis_yaml/tests/test_operations_datetime.py @@ -2,9 +2,9 @@ from datetime import datetime -import letsql.vendor.ibis as ibis -import letsql.vendor.ibis.expr.datatypes as dt -import letsql.vendor.ibis.expr.operations.temporal as tm +import 
xorq.vendor.ibis as ibis +import xorq.vendor.ibis.expr.datatypes as dt +import xorq.vendor.ibis.expr.operations.temporal as tm def test_date_extract(compiler): diff --git a/python/letsql/ibis_yaml/tests/test_relations.py b/python/xorq/ibis_yaml/tests/test_relations.py similarity index 98% rename from python/letsql/ibis_yaml/tests/test_relations.py rename to python/xorq/ibis_yaml/tests/test_relations.py index 3c57cd2d..b45fb2bd 100644 --- a/python/letsql/ibis_yaml/tests/test_relations.py +++ b/python/xorq/ibis_yaml/tests/test_relations.py @@ -1,4 +1,4 @@ -import letsql.vendor.ibis as ibis +import xorq.vendor.ibis as ibis def test_filter(compiler, t): diff --git a/python/letsql/ibis_yaml/tests/test_selection.py b/python/xorq/ibis_yaml/tests/test_selection.py similarity index 91% rename from python/letsql/ibis_yaml/tests/test_selection.py rename to python/xorq/ibis_yaml/tests/test_selection.py index ba04e9ad..83b99b0e 100644 --- a/python/letsql/ibis_yaml/tests/test_selection.py +++ b/python/xorq/ibis_yaml/tests/test_selection.py @@ -1,4 +1,4 @@ -import letsql.vendor.ibis as ibis +import xorq.vendor.ibis as ibis def test_selection_on_view(compiler): diff --git a/python/letsql/ibis_yaml/tests/test_sql.py b/python/xorq/ibis_yaml/tests/test_sql.py similarity index 82% rename from python/letsql/ibis_yaml/tests/test_sql.py rename to python/xorq/ibis_yaml/tests/test_sql.py index c1593e19..6ef9981c 100644 --- a/python/letsql/ibis_yaml/tests/test_sql.py +++ b/python/xorq/ibis_yaml/tests/test_sql.py @@ -1,13 +1,13 @@ -import letsql as ls -import letsql.vendor.ibis.expr.operations as ops -from letsql.expr.relations import RemoteTable, into_backend -from letsql.ibis_yaml.sql import find_remote_tables, generate_sql_plans +import xorq as xo +import xorq.vendor.ibis.expr.operations as ops +from xorq.expr.relations import RemoteTable, into_backend +from xorq.ibis_yaml.sql import find_remote_tables, generate_sql_plans def test_find_remote_tables_simple(): - db = ls.duckdb.connect() + 
db = xo.duckdb.connect() db.profile_name = "duckdb" - table = ls.memtable([(1, "a"), (2, "b")], columns=["id", "val"]) + table = xo.memtable([(1, "a"), (2, "b")], columns=["id", "val"]) backend = table._find_backend() backend.profile_name = "duckdb" remote_expr = into_backend(table, db) @@ -21,13 +21,13 @@ def test_find_remote_tables_simple(): def test_find_remote_tables_nested(): - db1 = ls.duckdb.connect() + db1 = xo.duckdb.connect() db1.profile_name = "duckdb" - db2 = ls.datafusion.connect() + db2 = xo.datafusion.connect() db2.profile_name = "datafusion" - table1 = ls.memtable([(1, "a"), (2, "b")], columns=["id", "val1"]) - table2 = ls.memtable([(1, "x"), (2, "y")], columns=["id", "val2"]) + table1 = xo.memtable([(1, "a"), (2, "b")], columns=["id", "val1"]) + table2 = xo.memtable([(1, "x"), (2, "y")], columns=["id", "val2"]) remote1 = into_backend(table1, db1) remote2 = into_backend(table2, db2) @@ -41,14 +41,14 @@ def test_find_remote_tables_nested(): def test_find_remote_tables(): - pg = ls.postgres.connect_examples() + pg = xo.postgres.connect_examples() pg.profile_name = "postgres" - db = ls.duckdb.connect() + db = xo.duckdb.connect() db.profile_name = "duckdb" batting = pg.table("batting") awards_players = db.read_parquet( - ls.config.options.pins.get_path("awards_players"), + xo.config.options.pins.get_path("awards_players"), table_name="awards_players", ) @@ -80,10 +80,10 @@ def print_tree(node, level=0): def test_generate_sql_plans_simple(): - db = ls.duckdb.connect() + db = xo.duckdb.connect() db.profile_name = "duckdb" - table = ls.memtable([(1, "a"), (2, "b")], columns=["id", "val"]) - expr = into_backend(table, db).filter(ls._.id > 1) + table = xo.memtable([(1, "a"), (2, "b")], columns=["id", "val"]) + expr = into_backend(table, db).filter(xo._.id > 1) plans = generate_sql_plans(expr) @@ -94,15 +94,15 @@ def test_generate_sql_plans_simple(): def test_generate_sql_plans_complex_example(): - pg = ls.postgres.connect_examples() + pg = 
xo.postgres.connect_examples() pg.profile_name = "postgres" - db = ls.duckdb.connect() + db = xo.duckdb.connect() db.profile_name = "duckdb" batting = pg.table("batting") awards_players = db.read_parquet( - ls.config.options.pins.get_path("awards_players"), + xo.config.options.pins.get_path("awards_players"), table_name="awards_players", ) diff --git a/python/letsql/ibis_yaml/tests/test_string_ops.py b/python/xorq/ibis_yaml/tests/test_string_ops.py similarity index 97% rename from python/letsql/ibis_yaml/tests/test_string_ops.py rename to python/xorq/ibis_yaml/tests/test_string_ops.py index 4af5d59e..0a2b3af3 100644 --- a/python/letsql/ibis_yaml/tests/test_string_ops.py +++ b/python/xorq/ibis_yaml/tests/test_string_ops.py @@ -1,4 +1,4 @@ -import letsql.vendor.ibis as ibis +import xorq.vendor.ibis as ibis def test_string_concat(compiler): diff --git a/python/letsql/ibis_yaml/tests/test_subquery.py b/python/xorq/ibis_yaml/tests/test_subquery.py similarity index 94% rename from python/letsql/ibis_yaml/tests/test_subquery.py rename to python/xorq/ibis_yaml/tests/test_subquery.py index 2b75fa08..0b9a720a 100644 --- a/python/letsql/ibis_yaml/tests/test_subquery.py +++ b/python/xorq/ibis_yaml/tests/test_subquery.py @@ -1,5 +1,5 @@ -import letsql.vendor.ibis as ibis -import letsql.vendor.ibis.expr.operations as ops +import xorq.vendor.ibis as ibis +import xorq.vendor.ibis.expr.operations as ops def test_scalar_subquery(compiler, t): diff --git a/python/letsql/ibis_yaml/tests/test_tpch.py b/python/xorq/ibis_yaml/tests/test_tpch.py similarity index 100% rename from python/letsql/ibis_yaml/tests/test_tpch.py rename to python/xorq/ibis_yaml/tests/test_tpch.py diff --git a/python/letsql/ibis_yaml/tests/test_udf.py b/python/xorq/ibis_yaml/tests/test_udf.py similarity index 93% rename from python/letsql/ibis_yaml/tests/test_udf.py rename to python/xorq/ibis_yaml/tests/test_udf.py index 5ba2cc56..70e01b7c 100644 --- a/python/letsql/ibis_yaml/tests/test_udf.py +++ 
b/python/xorq/ibis_yaml/tests/test_udf.py @@ -1,8 +1,8 @@ import pytest -import letsql.ibis_yaml -import letsql.ibis_yaml.utils -import letsql.vendor.ibis as ibis +import xorq.ibis_yaml +import xorq.ibis_yaml.utils +import xorq.vendor.ibis as ibis def test_built_in_udf_properties(compiler): @@ -66,6 +66,6 @@ def add_one(x: int) -> int: roundtrip_expr = compiler.from_yaml(yaml_dict) print(f"Original {expr}") print(f"Roundtrip {roundtrip_expr}") - letsql.ibis_yaml.utils.diff_ibis_exprs(expr, roundtrip_expr) + xorq.ibis_yaml.utils.diff_ibis_exprs(expr, roundtrip_expr) assert roundtrip_expr.equals(expr) diff --git a/python/letsql/ibis_yaml/tests/test_window_functions.py b/python/xorq/ibis_yaml/tests/test_window_functions.py similarity index 98% rename from python/letsql/ibis_yaml/tests/test_window_functions.py rename to python/xorq/ibis_yaml/tests/test_window_functions.py index aeb11066..ef2bea15 100644 --- a/python/letsql/ibis_yaml/tests/test_window_functions.py +++ b/python/xorq/ibis_yaml/tests/test_window_functions.py @@ -1,4 +1,4 @@ -import letsql.vendor.ibis as ibis +import xorq.vendor.ibis as ibis def test_window_function_roundtrip(compiler, t): diff --git a/python/letsql/ibis_yaml/translate.py b/python/xorq/ibis_yaml/translate.py similarity index 98% rename from python/letsql/ibis_yaml/translate.py rename to python/xorq/ibis_yaml/translate.py index a252e517..038f63a0 100644 --- a/python/letsql/ibis_yaml/translate.py +++ b/python/xorq/ibis_yaml/translate.py @@ -5,21 +5,21 @@ import functools from typing import Any -import letsql.vendor.ibis as ibis -import letsql.vendor.ibis.expr.datatypes as dt -import letsql.vendor.ibis.expr.operations as ops -import letsql.vendor.ibis.expr.operations.temporal as tm -import letsql.vendor.ibis.expr.rules as rlz -import letsql.vendor.ibis.expr.types as ir -from letsql.expr.relations import CachedNode, Read, RemoteTable, into_backend -from letsql.ibis_yaml.utils import ( +import xorq.vendor.ibis as ibis +import 
xorq.vendor.ibis.expr.datatypes as dt +import xorq.vendor.ibis.expr.operations as ops +import xorq.vendor.ibis.expr.operations.temporal as tm +import xorq.vendor.ibis.expr.rules as rlz +import xorq.vendor.ibis.expr.types as ir +from xorq.expr.relations import CachedNode, Read, RemoteTable, into_backend +from xorq.ibis_yaml.utils import ( deserialize_udf_function, freeze, load_storage_from_yaml, serialize_udf_function, translate_storage, ) -from letsql.vendor.ibis.common.annotations import Argument +from xorq.vendor.ibis.common.annotations import Argument FROM_YAML_HANDLERS: dict[str, Any] = {} diff --git a/python/letsql/ibis_yaml/utils.py b/python/xorq/ibis_yaml/utils.py similarity index 94% rename from python/letsql/ibis_yaml/utils.py rename to python/xorq/ibis_yaml/utils.py index 2b856ff5..06873d42 100644 --- a/python/letsql/ibis_yaml/utils.py +++ b/python/xorq/ibis_yaml/utils.py @@ -4,13 +4,13 @@ import cloudpickle -import letsql.vendor.ibis.expr.operations as ops -import letsql.vendor.ibis.expr.types as ir -from letsql.common.caching import SourceStorage -from letsql.expr.relations import CachedNode, Read, RemoteTable -from letsql.vendor.ibis.backends import BaseBackend -from letsql.vendor.ibis.common.collections import FrozenOrderedDict -from letsql.vendor.ibis.expr.types.relations import Table +import xorq.vendor.ibis.expr.operations as ops +import xorq.vendor.ibis.expr.types as ir +from xorq.common.caching import SourceStorage +from xorq.expr.relations import CachedNode, Read, RemoteTable +from xorq.vendor.ibis.backends import BaseBackend +from xorq.vendor.ibis.common.collections import FrozenOrderedDict +from xorq.vendor.ibis.expr.types.relations import Table def serialize_udf_function(fn: callable) -> str: diff --git a/python/xorq/vendor/ibis/backends/__init__.py b/python/xorq/vendor/ibis/backends/__init__.py index 4d76e9bb..24f634e3 100644 --- a/python/xorq/vendor/ibis/backends/__init__.py +++ b/python/xorq/vendor/ibis/backends/__init__.py @@ -34,16 +34,6 
@@ from xorq.common.utils.inspect_utils import get_arguments from xorq.vendor import ibis from xorq.vendor.ibis import util -import dask -import toolz -from attr import ( - field, - frozen, -) -from attr.validators import ( - instance_of, - optional, -) if TYPE_CHECKING: diff --git a/requirements-dev.txt b/requirements-dev.txt index 339c7efa..27552a4f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -59,6 +59,7 @@ googleapis-common-protos==1.68.0 ; python_full_version < '4.0' greenlet==3.1.1 ; (python_full_version < '3.14' and platform_machine == 'AMD64') or (python_full_version < '3.14' and platform_machine == 'WIN32') or (python_full_version < '3.14' and platform_machine == 'aarch64') or (python_full_version < '3.14' and platform_machine == 'amd64') or (python_full_version < '3.14' and platform_machine == 'ppc64le') or (python_full_version < '3.14' and platform_machine == 'win32') or (python_full_version < '3.14' and platform_machine == 'x86_64') griffe==1.5.7 humanize==4.12.1 ; python_full_version < '4.0' +hypothesis==6.126.0 identify==2.6.7 idna==3.10 importlib-metadata==8.6.1 @@ -147,7 +148,7 @@ scipy==1.15.2 setuptools==75.8.0 ; sys_platform == 'darwin' six==1.17.0 snowflake-connector-python==3.13.2 ; python_full_version < '4.0' -sortedcontainers==2.4.0 ; python_full_version < '4.0' +sortedcontainers==2.4.0 sphobjinv==2.3.1.2 sqlalchemy==2.0.38 ; python_full_version < '4.0' sqlglot==25.20.2 diff --git a/uv.lock b/uv.lock index 0932705f..12a71050 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.13'", From 2d5e6c8ea32a2d6db1874730092f2af97211d5bb Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Thu, 20 Feb 2025 18:51:45 -0500 Subject: [PATCH 25/45] wip: split out sql files and use walk_nodes to find opaque ops --- python/xorq/common/utils/graph_utils.py | 3 +- python/xorq/ibis_yaml/compiler.py | 5 +- python/xorq/ibis_yaml/sql.py | 129 
+++++++++---------- python/xorq/ibis_yaml/tests/test_compiler.py | 17 ++- python/xorq/ibis_yaml/tests/test_sql.py | 22 +--- python/xorq/ibis_yaml/utils.py | 89 +------------ 6 files changed, 83 insertions(+), 182 deletions(-) diff --git a/python/xorq/common/utils/graph_utils.py b/python/xorq/common/utils/graph_utils.py index 6908c24d..7b370da0 100644 --- a/python/xorq/common/utils/graph_utils.py +++ b/python/xorq/common/utils/graph_utils.py @@ -28,7 +28,8 @@ def inner(rest, seen): rest.update(set(new).difference(seen)) return inner(rest, seen) - rest = process_node(expr.op()) + initial_op = expr.op() if hasattr(expr, "op") else expr + rest = process_node(initial_op) return inner(set(rest), set()) diff --git a/python/xorq/ibis_yaml/compiler.py b/python/xorq/ibis_yaml/compiler.py index 16cb2671..b096273b 100644 --- a/python/xorq/ibis_yaml/compiler.py +++ b/python/xorq/ibis_yaml/compiler.py @@ -6,13 +6,14 @@ import yaml import xorq.vendor.ibis.expr.types as ir +from xorq.common.utils.graph_utils import find_all_sources from xorq.ibis_yaml.sql import generate_sql_plans from xorq.ibis_yaml.translate import ( SchemaRegistry, translate_from_yaml, translate_to_yaml, ) -from xorq.ibis_yaml.utils import find_all_backends, freeze +from xorq.ibis_yaml.utils import freeze from xorq.vendor.ibis.backends import Profile from xorq.vendor.ibis.common.collections import FrozenOrderedDict @@ -163,7 +164,7 @@ def compile_expr(self, expr: ir.Expr) -> None: expr_hash = self.artifact_store.get_expr_hash(expr) current_path = self.artifact_store.get_build_path(expr_hash) - backends = find_all_backends(expr.op()) + backends = find_all_sources(expr) profiles = { backend._profile.hash_name: backend._profile.as_dict() for backend in backends diff --git a/python/xorq/ibis_yaml/sql.py b/python/xorq/ibis_yaml/sql.py index 0473d738..f2e6be1e 100644 --- a/python/xorq/ibis_yaml/sql.py +++ b/python/xorq/ibis_yaml/sql.py @@ -1,10 +1,10 @@ -from typing import Any, Dict, TypedDict +from typing import 
Any, Dict, List, TypedDict import xorq.vendor.ibis as ibis import xorq.vendor.ibis.expr.operations as ops import xorq.vendor.ibis.expr.types as ir +from xorq.common.utils.graph_utils import find_all_sources, walk_nodes from xorq.expr.relations import Read, RemoteTable -from xorq.ibis_yaml.utils import find_all_backends, find_relations class QueryInfo(TypedDict): @@ -17,87 +17,84 @@ class SQLPlans(TypedDict): queries: Dict[str, QueryInfo] -def get_read_options(read_instance): - read_kwargs_list = [{k: v} for k, v in read_instance.read_kwargs] - - return { - "method_name": read_instance.method_name, - "name": read_instance.name, - "read_kwargs": read_kwargs_list, - } - - -def find_remote_tables(op) -> Dict[str, Dict[str, Any]]: - remote_tables = {} +def find_relations(expr: ir.Expr) -> List[str]: + node_types = (RemoteTable, Read, ops.DatabaseTable) + nodes = walk_nodes(node_types, expr) + relations = [] seen = set() + for node in nodes: + name = None + if isinstance(node, RemoteTable): + name = node.name + elif isinstance(node, Read): + name = node.make_unbound_dt().name + elif isinstance(node, ops.DatabaseTable): + name = node.name + if name and name not in seen: + seen.add(name) + relations.append(name) + return relations + + +def find_remote_tables(expr: ir.Expr) -> dict: + node_types = (RemoteTable, Read) + nodes = walk_nodes(node_types, expr) + remote_tables = {} - def traverse(node): - if node is None or id(node) in seen: - return - - seen.add(id(node)) - - if isinstance(node, ops.Node) and isinstance(node, RemoteTable): + for node in nodes: + if isinstance(node, RemoteTable): remote_expr = node.remote_expr - original_backend = find_all_backends(remote_expr)[ - 0 - ] # this was _find_backend before - - engine_name = original_backend.name - profile_name = original_backend._profile.hash_name - remote_tables[node.name] = { - "engine": engine_name, - "profile_name": profile_name, - "relations": find_relations(remote_expr), - "sql": ibis.to_sql(remote_expr), - 
"options": {}, - } - if isinstance(node, Read): - backend = node.source - if backend is not None: + backends = find_all_sources(node) + if len(backends) > 1: + backends = tuple( + x for x in backends if x != node.to_expr()._find_backend() + ) + for backend in backends: engine_name = backend.name profile_name = backend._profile.hash_name - remote_tables[node.make_unbound_dt().name] = { + key = f"{node.name}" + remote_tables[key] = { "engine": engine_name, "profile_name": profile_name, - "relations": [node.make_unbound_dt().name], - "sql": ibis.to_sql(node.make_unbound_dt().to_expr()), + "relations": find_relations(remote_expr), + "sql": ibis.to_sql(remote_expr).strip(), + "options": {}, + } + elif isinstance(node, Read): + backend = node.source + if backend is not None: + dt = node.make_unbound_dt() + key = dt.name + remote_tables[key] = { + "engine": backend.name, + "profile_name": backend._profile.hash_name, + "relations": [dt.name], + "sql": ibis.to_sql(dt.to_expr()).strip(), "options": get_read_options(node), } - - if isinstance(node, ops.Node): - for arg in node.args: - if isinstance(arg, ops.Node): - traverse(arg) - elif isinstance(arg, (list, tuple)): - for item in arg: - if isinstance(item, ops.Node): - traverse(item) - elif isinstance(arg, dict): - for v in arg.values(): - if isinstance(v, ops.Node): - traverse(v) - - traverse(op) return remote_tables -# TODO: rename to sqls -def generate_sql_plans(expr: ir.Expr) -> SQLPlans: - remote_tables = find_remote_tables(expr.op()) +def get_read_options(read_instance) -> Dict[str, Any]: + read_kwargs_list = [{k: v} for k, v in read_instance.read_kwargs] + return { + "method_name": read_instance.method_name, + "name": read_instance.name, + "read_kwargs": read_kwargs_list, + } + +def generate_sql_plans(expr: ir.Expr) -> SQLPlans: + remote_tables = find_remote_tables(expr) main_sql = ibis.to_sql(expr) backend = expr._find_backend() - engine_name = backend.name - profile_name = backend._profile.hash_name - plans: 
SQLPlans = { "queries": { "main": { - "engine": engine_name, - "profile_name": profile_name, - "relations": list(find_relations(expr)), + "engine": backend.name, + "profile_name": backend._profile.hash_name, + "relations": find_relations(expr), "sql": main_sql.strip(), "options": {}, } @@ -107,9 +104,9 @@ def generate_sql_plans(expr: ir.Expr) -> SQLPlans: for table_name, info in remote_tables.items(): plans["queries"][table_name] = { "engine": info["engine"], - "relations": info["relations"], "profile_name": info["profile_name"], - "sql": info["sql"].strip(), + "relations": info["relations"], + "sql": info["sql"], "options": info["options"], } diff --git a/python/xorq/ibis_yaml/tests/test_compiler.py b/python/xorq/ibis_yaml/tests/test_compiler.py index 54843553..b2670d7f 100644 --- a/python/xorq/ibis_yaml/tests/test_compiler.py +++ b/python/xorq/ibis_yaml/tests/test_compiler.py @@ -82,7 +82,6 @@ def test_ibis_compiler_parquet_reader(build_dir): expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") compiler = BuildManager(build_dir) compiler.compile_expr(expr) - print(dask.base.tokenize(expr)[:12]) expr_hash = "9a7d0b20d41a" roundtrip_expr = compiler.load_expr(expr_hash) @@ -113,14 +112,20 @@ def test_compiler_sql(build_dir): f" profile_name: {expr._find_backend()._profile.hash_name}\n" " relations:\n" " - awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\n" - ' sql: "SELECT\\n \\"t0\\".\\"playerID\\",\\n \\"t0\\".\\"awardID\\",\\n \\"t0\\".\\"tie\\"' - '\\\n ,\\n \\"t0\\".\\"notes\\"\\nFROM \\"awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\\"' - '\\\n \\ AS \\"t0\\"\\nWHERE\\n \\"t0\\".\\"lgID\\" = \'NL\'"\n' + " options: {}\n" + " sql_file: df34d95d62bc.sql\n" " awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f:\n" " engine: datafusion\n" - " relations: awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\n" " profile_name: a506210f56203e8f9b4a84ef73d95eaa\n" - ' sql: "SELECT\\n *\\nFROM \\"awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\\""\n' 
+ " relations:\n" + " - awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\n" + " options:\n" + " method_name: read_parquet\n" + " name: awards_players\n" + " read_kwargs:\n" + " - path: /home/hussainsultan/.cache/pins-py/gs_d3037fb8920d01eb3b262ab08d52335c89ba62aa41299e5236f01807aa8b726d/awards_players/20240711T171119Z-886c4/awards_players.parquet\n" + " - table_name: awards_players\n" + " sql_file: c0907dab80b0.sql\n" ) assert sql_text == expected_result diff --git a/python/xorq/ibis_yaml/tests/test_sql.py b/python/xorq/ibis_yaml/tests/test_sql.py index 6ef9981c..e58745dc 100644 --- a/python/xorq/ibis_yaml/tests/test_sql.py +++ b/python/xorq/ibis_yaml/tests/test_sql.py @@ -1,18 +1,14 @@ import xorq as xo -import xorq.vendor.ibis.expr.operations as ops -from xorq.expr.relations import RemoteTable, into_backend +from xorq.expr.relations import into_backend from xorq.ibis_yaml.sql import find_remote_tables, generate_sql_plans def test_find_remote_tables_simple(): db = xo.duckdb.connect() - db.profile_name = "duckdb" table = xo.memtable([(1, "a"), (2, "b")], columns=["id", "val"]) - backend = table._find_backend() - backend.profile_name = "duckdb" remote_expr = into_backend(table, db) - remote_tables = find_remote_tables(remote_expr.op()) + remote_tables = find_remote_tables(remote_expr) assert len(remote_tables) == 1 table_name = next(iter(remote_tables)) @@ -58,21 +54,9 @@ def test_find_remote_tables(): ["yearID", "stint"] ] - def print_tree(node, level=0): - indent = " " * level - print(f"{indent}{type(node).__name__}") - if hasattr(node, "args"): - for arg in node.args: - if isinstance(arg, (ops.Node, RemoteTable)): - print_tree(arg, level + 1) - - print_tree(expr.op()) - remote_tables = find_remote_tables(expr.op()) - assert len(remote_tables) == 1, ( - f"Expected 1 remote table, found {len(remote_tables)}" - ) + assert len(remote_tables) == 1 first_table = next(iter(remote_tables.values())) assert "sql" in first_table, "SQL query missing from remote table info" 
diff --git a/python/xorq/ibis_yaml/utils.py b/python/xorq/ibis_yaml/utils.py index 06873d42..6cb02125 100644 --- a/python/xorq/ibis_yaml/utils.py +++ b/python/xorq/ibis_yaml/utils.py @@ -1,16 +1,11 @@ import base64 from collections.abc import Mapping, Sequence -from typing import Any, Dict, List, Tuple +from typing import Any, Dict import cloudpickle -import xorq.vendor.ibis.expr.operations as ops -import xorq.vendor.ibis.expr.types as ir from xorq.common.caching import SourceStorage -from xorq.expr.relations import CachedNode, Read, RemoteTable -from xorq.vendor.ibis.backends import BaseBackend from xorq.vendor.ibis.common.collections import FrozenOrderedDict -from xorq.vendor.ibis.expr.types.relations import Table def serialize_udf_function(fn: callable) -> str: @@ -152,85 +147,3 @@ def load_storage_from_yaml(storage_yaml: Dict, compiler: Any): return SourceStorage(source=source) else: raise NotImplementedError(f"Unknown storage type: {storage_yaml['type']}") - - -def find_all_backends(expr: ir.Expr) -> Tuple[BaseBackend, ...]: - backends = set() - seen = set() - - def traverse(node): - if node is None or id(node) in seen: - return - seen.add(id(node)) - - if isinstance(node, Table): - traverse(node.op()) - return - - if isinstance(node, Read): - backend = node.source - if backend is not None: - backends.add(backend) - elif isinstance(node, RemoteTable): - # this needs to habdle when a RemoteTable has Read op since the backend for the op is - # not the same as _find_backend() - backends.add(*find_all_backends(node.remote_expr)) - - elif isinstance(node, ops.DatabaseTable): - backends.add(node.source) - - elif isinstance(node, ops.SQLQueryResult): # caching_utils uses - backends.add(node.source) - - elif isinstance(node, CachedNode): - backends.add(node.source) - - if isinstance(node, ops.Node): - for arg in node.args: - if isinstance(arg, ops.Node): - traverse(arg) - elif isinstance(arg, (list, tuple)): - for item in arg: - if isinstance(item, ops.Node): - 
traverse(item) - elif isinstance(arg, dict): - for v in arg.values(): - if isinstance(v, ops.Node): - traverse(v) - - traverse(expr) - - return tuple(backends) - - -def find_relations(expr: ir.Expr) -> List[str]: - relations = [] - seen = set() - - def traverse(node): - if node is None or id(node) in seen: - return - seen.add(id(node)) - - if isinstance(node, ops.Node): - if isinstance(node, RemoteTable): - relations.append(node.name) - elif isinstance(node, Read): - relations.append(node.make_unbound_dt().name) - elif isinstance(node, ops.DatabaseTable): - relations.append(node.name) - - for arg in node.args: - if isinstance(arg, ops.Node): - traverse(arg) - elif isinstance(arg, (list, tuple)): - for item in arg: - if isinstance(item, ops.Node): - traverse(item) - elif isinstance(arg, dict): - for v in arg.values(): - if isinstance(v, ops.Node): - traverse(v) - - traverse(expr.op()) - return list(dict.fromkeys(relations)) From c5f97c610e530a9149a87a5bed77f34f1d821d0c Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Fri, 21 Feb 2025 12:29:25 -0500 Subject: [PATCH 26/45] feat: split deferred reads into its own yaml --- python/xorq/ibis_yaml/compiler.py | 31 +++++++++++-- python/xorq/ibis_yaml/sql.py | 48 ++++++++++---------- python/xorq/ibis_yaml/tests/test_compiler.py | 42 +++++++++++++++++ python/xorq/ibis_yaml/tests/test_sql.py | 18 ++++---- 4 files changed, 102 insertions(+), 37 deletions(-) diff --git a/python/xorq/ibis_yaml/compiler.py b/python/xorq/ibis_yaml/compiler.py index b096273b..1f604326 100644 --- a/python/xorq/ibis_yaml/compiler.py +++ b/python/xorq/ibis_yaml/compiler.py @@ -160,7 +160,22 @@ def _process_sql_plans( return updated_plans - def compile_expr(self, expr: ir.Expr) -> None: + def _process_deferred_reads( + self, deferred_reads: Dict[str, Any], expr_hash: str + ) -> Dict[str, Any]: + updated_reads = {"reads": {}} + + for read_name, read_info in deferred_reads["reads"].items(): + sql_filename = self._write_sql_file(read_info["sql"], 
expr_hash, read_name) + + updated_read_info = read_info.copy() + updated_read_info["sql_file"] = sql_filename + updated_read_info.pop("sql") + updated_reads["reads"][read_name] = updated_read_info + + return updated_reads + + def compile_expr(self, expr: ir.Expr) -> str: expr_hash = self.artifact_store.get_expr_hash(expr) current_path = self.artifact_store.get_build_path(expr_hash) @@ -175,15 +190,20 @@ def compile_expr(self, expr: ir.Expr) -> None: translator = YamlExpressionTranslator( profiles=profiles, current_path=current_path ) - # metadata.yaml (uv.lock, git commit version, version==xorq_internal_version, user, hostname, ip_address(host ip)) yaml_dict = translator.to_yaml(expr) self.artifact_store.save_yaml(yaml_dict, expr_hash, "expr.yaml") - self.artifact_store.save_yaml(profiles, expr_hash, "profiles.yaml") - sql_plans = generate_sql_plans(expr) + sql_plans, deferred_reads = generate_sql_plans(expr) + updated_sql_plans = self._process_sql_plans(sql_plans, expr_hash) self.artifact_store.save_yaml(updated_sql_plans, expr_hash, "sql.yaml") + + updated_deferred_reads = self._process_deferred_reads(deferred_reads, expr_hash) + self.artifact_store.save_yaml( + updated_deferred_reads, expr_hash, "deferred_reads.yaml" + ) + return expr_hash def load_expr(self, expr_hash: str) -> ir.Expr: @@ -209,3 +229,6 @@ def f(values): # TODO: maybe change name def load_sql_plans(self, expr_hash: str) -> Dict[str, Any]: return self.artifact_store.load_yaml(expr_hash, "sql.yaml") + + def load_deferred_reads(self, expr_hash: str) -> Dict[str, Any]: + return self.artifact_store.load_yaml(expr_hash, "deferred_reads.yaml") diff --git a/python/xorq/ibis_yaml/sql.py b/python/xorq/ibis_yaml/sql.py index f2e6be1e..73c4efdc 100644 --- a/python/xorq/ibis_yaml/sql.py +++ b/python/xorq/ibis_yaml/sql.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, TypedDict +from typing import Any, Dict, List, Tuple, TypedDict import xorq.vendor.ibis as ibis import xorq.vendor.ibis.expr.operations 
as ops @@ -17,6 +17,10 @@ class SQLPlans(TypedDict): queries: Dict[str, QueryInfo] +class DeferredReadsPlan(TypedDict): + reads: Dict[str, QueryInfo] + + def find_relations(expr: ir.Expr) -> List[str]: node_types = (RemoteTable, Read, ops.DatabaseTable) nodes = walk_nodes(node_types, expr) @@ -36,10 +40,12 @@ def find_relations(expr: ir.Expr) -> List[str]: return relations -def find_remote_tables(expr: ir.Expr) -> dict: +def find_tables(expr: ir.Expr) -> Tuple[Dict[str, QueryInfo], Dict[str, QueryInfo]]: + remote_tables: Dict[str, QueryInfo] = {} + deferred_reads: Dict[str, QueryInfo] = {} + node_types = (RemoteTable, Read) nodes = walk_nodes(node_types, expr) - remote_tables = {} for node in nodes: if isinstance(node, RemoteTable): @@ -65,14 +71,14 @@ def find_remote_tables(expr: ir.Expr) -> dict: if backend is not None: dt = node.make_unbound_dt() key = dt.name - remote_tables[key] = { + deferred_reads[key] = { "engine": backend.name, "profile_name": backend._profile.hash_name, "relations": [dt.name], "sql": ibis.to_sql(dt.to_expr()).strip(), "options": get_read_options(node), } - return remote_tables + return remote_tables, deferred_reads def get_read_options(read_instance) -> Dict[str, Any]: @@ -84,30 +90,24 @@ def get_read_options(read_instance) -> Dict[str, Any]: } -def generate_sql_plans(expr: ir.Expr) -> SQLPlans: - remote_tables = find_remote_tables(expr) +def generate_sql_plans(expr: ir.Expr) -> Tuple[SQLPlans, DeferredReadsPlan]: + remote_tables, deferred_reads = find_tables(expr) main_sql = ibis.to_sql(expr) backend = expr._find_backend() - plans: SQLPlans = { - "queries": { - "main": { - "engine": backend.name, - "profile_name": backend._profile.hash_name, - "relations": find_relations(expr), - "sql": main_sql.strip(), - "options": {}, - } + queries: Dict[str, QueryInfo] = { + "main": { + "engine": backend.name, + "profile_name": backend._profile.hash_name, + "relations": find_relations(expr), + "sql": main_sql.strip(), + "options": {}, } } for 
table_name, info in remote_tables.items(): - plans["queries"][table_name] = { - "engine": info["engine"], - "profile_name": info["profile_name"], - "relations": info["relations"], - "sql": info["sql"], - "options": info["options"], - } + queries[table_name] = info - return plans + sql_plans: SQLPlans = {"queries": queries} + deferred_reads_plans: DeferredReadsPlan = {"reads": deferred_reads} + return sql_plans, deferred_reads_plans diff --git a/python/xorq/ibis_yaml/tests/test_compiler.py b/python/xorq/ibis_yaml/tests/test_compiler.py index b2670d7f..04080797 100644 --- a/python/xorq/ibis_yaml/tests/test_compiler.py +++ b/python/xorq/ibis_yaml/tests/test_compiler.py @@ -114,6 +114,29 @@ def test_compiler_sql(build_dir): " - awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\n" " options: {}\n" " sql_file: df34d95d62bc.sql\n" + ) + assert sql_text == expected_result + + +def test_deferred_reads_yaml(build_dir): + backend = xo.datafusion.connect() + awards_players = deferred_read_parquet( + backend, + xo.config.options.pins.get_path("awards_players"), + table_name="awards_players", + ) + expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") + + compiler = BuildManager(build_dir) + compiler.compile_expr(expr) + expr_hash = "79d83e9c89ad" + _roundtrip_expr = compiler.load_expr(expr_hash) + assert os.path.exists(build_dir / expr_hash / "deferred_reads.yaml") + + sql_text = pathlib.Path(build_dir / expr_hash / "deferred_reads.yaml").read_text() + + expected_result = ( + "reads:\n" " awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f:\n" " engine: datafusion\n" " profile_name: a506210f56203e8f9b4a84ef73d95eaa\n" @@ -141,3 +164,22 @@ def test_ibis_compiler_expr_schema_ref(t, build_dir): yaml_dict = yaml.safe_load(f) assert yaml_dict["expression"]["schema_ref"] + + +def test_multi_engine_deferred_reads(build_dir): + con0 = xo.connect() + con1 = xo.connect() + con2 = xo.duckdb.connect() + con3 = xo.connect() + + awards_players = 
xo.examples.awards_players.fetch(con0).into_backend(con1) + batting = xo.examples.batting.fetch(con2).into_backend(con1) + expr = awards_players.join( + batting, predicates=["playerID", "yearID", "lgID"] + ).into_backend(con3)[lambda t: t.G == 1] + compiler = BuildManager(build_dir) + expr_hash = compiler.compile_expr(expr) + + roundtrip_expr = compiler.load_expr(expr_hash) + + assert expr.execute().equals(roundtrip_expr.execute()) diff --git a/python/xorq/ibis_yaml/tests/test_sql.py b/python/xorq/ibis_yaml/tests/test_sql.py index e58745dc..c0713702 100644 --- a/python/xorq/ibis_yaml/tests/test_sql.py +++ b/python/xorq/ibis_yaml/tests/test_sql.py @@ -1,14 +1,14 @@ import xorq as xo from xorq.expr.relations import into_backend -from xorq.ibis_yaml.sql import find_remote_tables, generate_sql_plans +from xorq.ibis_yaml.sql import find_tables, generate_sql_plans -def test_find_remote_tables_simple(): +def test_find_tables_simple(): db = xo.duckdb.connect() table = xo.memtable([(1, "a"), (2, "b")], columns=["id", "val"]) remote_expr = into_backend(table, db) - remote_tables = find_remote_tables(remote_expr) + remote_tables, _ = find_tables(remote_expr) assert len(remote_tables) == 1 table_name = next(iter(remote_tables)) @@ -16,7 +16,7 @@ def test_find_remote_tables_simple(): assert remote_tables[table_name]["engine"] == "duckdb" -def test_find_remote_tables_nested(): +def test_find_tables_nested(): db1 = xo.duckdb.connect() db1.profile_name = "duckdb" db2 = xo.datafusion.connect() @@ -29,14 +29,14 @@ def test_find_remote_tables_nested(): remote2 = into_backend(table2, db2) expr = remote1.join(remote2, "id") - remote_tables = find_remote_tables(expr.op()) + remote_tables, _ = find_tables(expr.op()) assert len(remote_tables) == 2 assert all(name.startswith("ibis_remote") for name in remote_tables) assert all("engine" in info and "sql" in info for info in remote_tables.values()) -def test_find_remote_tables(): +def test_find_tables(): pg = xo.postgres.connect_examples() 
pg.profile_name = "postgres" db = xo.duckdb.connect() @@ -54,7 +54,7 @@ def test_find_remote_tables(): ["yearID", "stint"] ] - remote_tables = find_remote_tables(expr.op()) + remote_tables, _ = find_tables(expr.op()) assert len(remote_tables) == 1 @@ -69,7 +69,7 @@ def test_generate_sql_plans_simple(): table = xo.memtable([(1, "a"), (2, "b")], columns=["id", "val"]) expr = into_backend(table, db).filter(xo._.id > 1) - plans = generate_sql_plans(expr) + plans, _deferred_reads = generate_sql_plans(expr) assert "queries" in plans assert "main" in plans["queries"] @@ -96,7 +96,7 @@ def test_generate_sql_plans_complex_example(): ["yearID", "stint"] ] - plans = generate_sql_plans(expr) + plans, _deferred_reads = generate_sql_plans(expr) assert "queries" in plans assert len(plans["queries"]) == 2 From 4a5685bf779651195bc6fe9c17d809712a9330bd Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Fri, 21 Feb 2025 13:01:36 -0500 Subject: [PATCH 27/45] fix: Read op should have frozen read_kwargs --- python/xorq/ibis_yaml/tests/test_compiler.py | 29 ++++++++++++++++++-- python/xorq/ibis_yaml/translate.py | 4 +-- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/python/xorq/ibis_yaml/tests/test_compiler.py b/python/xorq/ibis_yaml/tests/test_compiler.py index 04080797..3dc5e751 100644 --- a/python/xorq/ibis_yaml/tests/test_compiler.py +++ b/python/xorq/ibis_yaml/tests/test_compiler.py @@ -174,9 +174,32 @@ def test_multi_engine_deferred_reads(build_dir): awards_players = xo.examples.awards_players.fetch(con0).into_backend(con1) batting = xo.examples.batting.fetch(con2).into_backend(con1) - expr = awards_players.join( - batting, predicates=["playerID", "yearID", "lgID"] - ).into_backend(con3)[lambda t: t.G == 1] + expr = ( + awards_players.join(batting, predicates=["playerID", "yearID", "lgID"]) + .into_backend(con3) + .filter(xo._.G == 1) + ) + compiler = BuildManager(build_dir) + expr_hash = compiler.compile_expr(expr) + + roundtrip_expr = 
compiler.load_expr(expr_hash) + + assert expr.execute().equals(roundtrip_expr.execute()) + + +def test_multi_engine_with_caching(build_dir): + con0 = xo.connect() + con1 = xo.connect() + con2 = xo.duckdb.connect() + con3 = xo.connect() + + awards_players = xo.examples.awards_players.fetch(con0).into_backend(con1).cache() + batting = xo.examples.batting.fetch(con2).into_backend(con1).cache() + expr = ( + awards_players.join(batting, predicates=["playerID", "yearID", "lgID"]) + .into_backend(con3) + .filter(xo._.G == 1) + ) compiler = BuildManager(build_dir) expr_hash = compiler.compile_expr(expr) diff --git a/python/xorq/ibis_yaml/translate.py b/python/xorq/ibis_yaml/translate.py index 038f63a0..dbaaefbc 100644 --- a/python/xorq/ibis_yaml/translate.py +++ b/python/xorq/ibis_yaml/translate.py @@ -327,7 +327,7 @@ def _cached_node_to_yaml(op: CachedNode, compiler: any) -> dict: "parent": translate_to_yaml(op.parent, compiler), "source": op.source._profile.hash_name, "storage": translate_storage(op.storage, compiler), - "values": dict(op.values), + "values": freeze(op.values), } ) @@ -415,7 +415,7 @@ def _read_to_yaml(op: Read, compiler: Any) -> dict: "name": op.name, "schema_ref": schema_id, "profile": profile_hash_name, - "read_kwargs": dict(op.read_kwargs) if op.read_kwargs else {}, + "read_kwargs": freeze(op.read_kwargs if op.read_kwargs else {}), } ) From bdd1582863871893ec55ff5da401a3f837f5cfcc Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Fri, 21 Feb 2025 14:32:53 -0500 Subject: [PATCH 28/45] refactor: refactor YamlExpressionCompiler so it does not have any state --- python/xorq/ibis_yaml/compiler.py | 71 +++++++++++-------- python/xorq/ibis_yaml/tests/test_compiler.py | 2 +- .../xorq/ibis_yaml/tests/test_letsql_ops.py | 29 ++++---- python/xorq/ibis_yaml/translate.py | 6 ++ 4 files changed, 60 insertions(+), 48 deletions(-) diff --git a/python/xorq/ibis_yaml/compiler.py b/python/xorq/ibis_yaml/compiler.py index 1f604326..eb38a4cb 100644 --- 
a/python/xorq/ibis_yaml/compiler.py +++ b/python/xorq/ibis_yaml/compiler.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Any, Dict +import attr import dask import yaml @@ -95,39 +96,57 @@ def get_build_path(self, expr_hash: str) -> pathlib.Path: return self.ensure_dir(expr_hash) +@attr.s(frozen=True, auto_attribs=True) +class TranslationContext: + schema_registry: SchemaRegistry = attr.ib(factory=SchemaRegistry) + profiles: FrozenOrderedDict = attr.ib(factory=FrozenOrderedDict) + definitions: FrozenOrderedDict = attr.ib(factory=lambda: freeze({"schemas": {}})) + + def update_definitions(self, new_definitions: FrozenOrderedDict): + return attr.evolve(self, definitions=new_definitions) + + class YamlExpressionTranslator: - def __init__( + def __init__(self): + pass + + def to_yaml( self, - schema_registry: SchemaRegistry = None, - profiles: Dict = None, - current_path: Path = None, - ): - self.schema_registry = schema_registry or SchemaRegistry() - self.definitions = {} - self.profiles = profiles or {} - self.current_path = current_path - - def to_yaml(self, expr: ir.Expr) -> Dict[str, Any]: - schema_ref = self._register_expr_schema(expr) - expr_dict = translate_to_yaml(expr, self) + expr: ir.Expr, + profiles=None, + ) -> Dict[str, Any]: + context = TranslationContext( + schema_registry=SchemaRegistry(), + profiles=freeze(profiles or {}), + ) + schema_ref = context.schema_registry._register_expr_schema(expr) + expr_dict = translate_to_yaml(expr, context) expr_dict = freeze({**dict(expr_dict), "schema_ref": schema_ref}) return freeze( { - "definitions": {"schemas": self.schema_registry.schemas}, + "definitions": {"schemas": context.schema_registry.schemas}, "expression": expr_dict, } ) - def from_yaml(self, yaml_dict: Dict[str, Any]) -> ir.Expr: - self.definitions = yaml_dict.get("definitions", {}) + def from_yaml( + self, + yaml_dict: Dict[str, Any], + profiles=None, + ) -> ir.Expr: + context = TranslationContext( + schema_registry=SchemaRegistry(), + 
profiles=freeze(profiles or {}), + ) + context = context.update_definitions(freeze(yaml_dict.get("definitions", {}))) expr_dict = freeze(yaml_dict["expression"]) - return translate_from_yaml(expr_dict, self) + return translate_from_yaml(expr_dict, freeze(context)) def _register_expr_schema(self, expr: ir.Expr) -> str: if hasattr(expr, "schema"): schema = expr.schema() - return self.schema_registry.register_schema(schema) + return self.context.schema_registry.register_schema(schema) return None @@ -177,7 +196,6 @@ def _process_deferred_reads( def compile_expr(self, expr: ir.Expr) -> str: expr_hash = self.artifact_store.get_expr_hash(expr) - current_path = self.artifact_store.get_build_path(expr_hash) backends = find_all_sources(expr) profiles = { @@ -185,12 +203,8 @@ def compile_expr(self, expr: ir.Expr) -> str: for backend in backends } - print(profiles) - - translator = YamlExpressionTranslator( - profiles=profiles, current_path=current_path - ) - yaml_dict = translator.to_yaml(expr) + translator = YamlExpressionTranslator() + yaml_dict = translator.to_yaml(expr, profiles) self.artifact_store.save_yaml(yaml_dict, expr_hash, "expr.yaml") self.artifact_store.save_yaml(profiles, expr_hash, "profiles.yaml") @@ -207,7 +221,6 @@ def compile_expr(self, expr: ir.Expr) -> str: return expr_hash def load_expr(self, expr_hash: str) -> ir.Expr: - build_path = self.artifact_store.get_build_path(expr_hash) profiles_dict = self.artifact_store.load_yaml(expr_hash, "profiles.yaml") def f(values): @@ -219,12 +232,10 @@ def f(values): profile: Profile(**f(values)).get_con() for profile, values in profiles_dict.items() } - translator = YamlExpressionTranslator( - current_path=build_path, profiles=profiles - ) + translator = YamlExpressionTranslator() yaml_dict = self.artifact_store.load_yaml(expr_hash, "expr.yaml") - return translator.from_yaml(yaml_dict) + return translator.from_yaml(yaml_dict, profiles=profiles) # TODO: maybe change name def load_sql_plans(self, expr_hash: str) -> 
Dict[str, Any]: diff --git a/python/xorq/ibis_yaml/tests/test_compiler.py b/python/xorq/ibis_yaml/tests/test_compiler.py index 3dc5e751..352fef30 100644 --- a/python/xorq/ibis_yaml/tests/test_compiler.py +++ b/python/xorq/ibis_yaml/tests/test_compiler.py @@ -139,7 +139,7 @@ def test_deferred_reads_yaml(build_dir): "reads:\n" " awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f:\n" " engine: datafusion\n" - " profile_name: a506210f56203e8f9b4a84ef73d95eaa\n" + " profile_name: 30174be6bf62a829d7e62af391fc53b2\n" " relations:\n" " - awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\n" " options:\n" diff --git a/python/xorq/ibis_yaml/tests/test_letsql_ops.py b/python/xorq/ibis_yaml/tests/test_letsql_ops.py index 9ff50f71..0637a49f 100644 --- a/python/xorq/ibis_yaml/tests/test_letsql_ops.py +++ b/python/xorq/ibis_yaml/tests/test_letsql_ops.py @@ -41,19 +41,14 @@ def prepare_duckdb_con(duckdb_path): def test_duckdb_database_table_roundtrip(prepare_duckdb_con, build_dir): con = prepare_duckdb_con - profiles = {con._profile.hash_name: con} - table_expr = con.table("mytable") expr1 = table_expr.mutate(new_val=(table_expr.val + "_extra")) - compiler = YamlExpressionTranslator(current_path=build_dir, profiles=profiles) - - yaml_dict = compiler.to_yaml(expr1) + compiler = YamlExpressionTranslator() - print("Serialized YAML:\n", yaml_dict) - - roundtrip_expr = compiler.from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr1, profiles) + roundtrip_expr = compiler.from_yaml(yaml_dict, profiles) df_original = expr1.execute() df_roundtrip = roundtrip_expr.execute() @@ -69,9 +64,9 @@ def test_memtable(build_dir): profiles = {backend._profile.hash_name: backend} - compiler = YamlExpressionTranslator(current_path=build_dir, profiles=profiles) + compiler = YamlExpressionTranslator() - yaml_dict = compiler.to_yaml(expr) + yaml_dict = compiler.to_yaml(expr, profiles, build_dir) roundtrip_expr = compiler.from_yaml(yaml_dict) expr.equals(roundtrip_expr) @@ -97,10 +92,10 @@ def 
test_into_backend(build_dir): con3._profile.hash_name: con3, } - compiler = YamlExpressionTranslator(current_path=build_dir, profiles=profiles) + compiler = YamlExpressionTranslator() - yaml_dict = compiler.to_yaml(expr) - roundtrip_expr = compiler.from_yaml(yaml_dict) + yaml_dict = compiler.to_yaml(expr, profiles) + roundtrip_expr = compiler.from_yaml(yaml_dict, profiles) assert xo.execute(expr).equals(xo.execute(roundtrip_expr)) @@ -117,7 +112,7 @@ def test_memtable_cache(build_dir): backend1._profile.hash_name: backend1, } - compiler = YamlExpressionTranslator(profiles=profiles, current_path=build_dir) + compiler = YamlExpressionTranslator(profiles=profiles) yaml_dict = compiler.to_yaml(expr) roundtrip_expr = compiler.from_yaml(yaml_dict) @@ -134,8 +129,8 @@ def test_deferred_read_csv(build_dir): ) profiles = {pd_con._profile.hash_name: pd_con} - compiler = YamlExpressionTranslator(profiles=profiles, current_path=build_dir) - yaml_dict = compiler.to_yaml(expr) - roundtrip_expr = compiler.from_yaml(yaml_dict) + compiler = YamlExpressionTranslator() + yaml_dict = compiler.to_yaml(expr, profiles) + roundtrip_expr = compiler.from_yaml(yaml_dict, profiles) assert xo.execute(expr).equals(xo.execute(roundtrip_expr)) diff --git a/python/xorq/ibis_yaml/translate.py b/python/xorq/ibis_yaml/translate.py index dbaaefbc..721eff61 100644 --- a/python/xorq/ibis_yaml/translate.py +++ b/python/xorq/ibis_yaml/translate.py @@ -44,6 +44,12 @@ def register_schema(self, schema): self.counter += 1 return schema_id + def _register_expr_schema(self, expr: ir.Expr) -> str: + if hasattr(expr, "schema"): + schema = expr.schema() + return self.register_schema(schema) + return None + def register_from_yaml_handler(*op_names: str): def decorator(func): From 359d239b76aed57eec4dfadb9de8937274e6b83c Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Fri, 21 Feb 2025 16:15:38 -0500 Subject: [PATCH 29/45] feat: add BuildConfig feat: disable yaml alias/achors refactor: remove unnecessary 
`_register_expr_schema` in YamlExpressionTranslator refactor: remove unnecessary definitions checks --- python/xorq/ibis_yaml/compiler.py | 16 +++++++------- python/xorq/ibis_yaml/config.py | 19 ++++++++++++++++ python/xorq/ibis_yaml/tests/test_compiler.py | 7 ++---- python/xorq/ibis_yaml/tests/test_config.py | 23 ++++++++++++++++++++ python/xorq/ibis_yaml/translate.py | 12 ---------- 5 files changed, 52 insertions(+), 25 deletions(-) create mode 100644 python/xorq/ibis_yaml/config.py create mode 100644 python/xorq/ibis_yaml/tests/test_config.py diff --git a/python/xorq/ibis_yaml/compiler.py b/python/xorq/ibis_yaml/compiler.py index eb38a4cb..5a88dcd3 100644 --- a/python/xorq/ibis_yaml/compiler.py +++ b/python/xorq/ibis_yaml/compiler.py @@ -8,6 +8,7 @@ import xorq.vendor.ibis.expr.types as ir from xorq.common.utils.graph_utils import find_all_sources +from xorq.ibis_yaml.config import config from xorq.ibis_yaml.sql import generate_sql_plans from xorq.ibis_yaml.translate import ( SchemaRegistry, @@ -24,6 +25,9 @@ class CleanDictYAMLDumper(yaml.SafeDumper): def represent_frozenordereddict(self, data): return self.represent_dict(dict(data)) + def ignore_aliases(self, data): + return True + CleanDictYAMLDumper.add_representer( FrozenOrderedDict, CleanDictYAMLDumper.represent_frozenordereddict @@ -84,7 +88,8 @@ def exists(self, *path_parts) -> bool: def get_expr_hash(self, expr) -> str: expr_hash = dask.base.tokenize(expr) - return expr_hash[:12] # TODO: make length of hash as a config + hash_length = config.hash_length + return expr_hash[:hash_length] # TODO: make length of hash as a config def save_yaml(self, yaml_dict: Dict[str, Any], expr_hash, filename) -> pathlib.Path: return self.write_yaml(yaml_dict, expr_hash, filename) @@ -143,12 +148,6 @@ def from_yaml( expr_dict = freeze(yaml_dict["expression"]) return translate_from_yaml(expr_dict, freeze(context)) - def _register_expr_schema(self, expr: ir.Expr) -> str: - if hasattr(expr, "schema"): - schema = 
expr.schema() - return self.context.schema_registry.register_schema(schema) - return None - class BuildManager: def __init__(self, build_dir: pathlib.Path): @@ -156,7 +155,8 @@ def __init__(self, build_dir: pathlib.Path): self.profiles = {} def _write_sql_file(self, sql: str, expr_hash: str, query_name: str) -> str: - sql_hash = dask.base.tokenize(sql)[:12] + hash_length = config.hash_length + sql_hash = dask.base.tokenize(sql)[:hash_length] filename = f"{sql_hash}.sql" sql_path = self.artifact_store.get_build_path(expr_hash) / filename sql_path.write_text(sql) diff --git a/python/xorq/ibis_yaml/config.py b/python/xorq/ibis_yaml/config.py new file mode 100644 index 00000000..1d57962a --- /dev/null +++ b/python/xorq/ibis_yaml/config.py @@ -0,0 +1,19 @@ +import pathlib + +from xorq.vendor.ibis.config import Config + + +class BuildConfig(Config): + hash_length: int = 12 + _build_path: pathlib.Path = pathlib.Path.cwd() / "builds" + + @property + def build_path(self) -> pathlib.Path: + return self._build_path + + @build_path.setter + def build_path(self, value: pathlib.Path): + self._build_path = value + + +config = BuildConfig() diff --git a/python/xorq/ibis_yaml/tests/test_compiler.py b/python/xorq/ibis_yaml/tests/test_compiler.py index 352fef30..fe81b212 100644 --- a/python/xorq/ibis_yaml/tests/test_compiler.py +++ b/python/xorq/ibis_yaml/tests/test_compiler.py @@ -1,7 +1,6 @@ import os import pathlib -import dask import pytest import yaml @@ -65,8 +64,7 @@ def test_ibis_compiler(t, build_dir): t = xo.memtable({"a": [0, 1], "b": [0, 1]}) expr = t.filter(t.a == 1).drop("b") compiler = BuildManager(build_dir) - compiler.compile_expr(expr) - expr_hash = dask.base.tokenize(expr)[:12] + expr_hash = compiler.compile_expr(expr) roundtrip_expr = compiler.load_expr(expr_hash) @@ -157,8 +155,7 @@ def test_ibis_compiler_expr_schema_ref(t, build_dir): t = xo.memtable({"a": [0, 1], "b": [0, 1]}) expr = t.filter(t.a == 1).drop("b") compiler = BuildManager(build_dir) - 
compiler.compile_expr(expr) - expr_hash = dask.base.tokenize(expr)[:12] + expr_hash = compiler.compile_expr(expr) with open(build_dir / expr_hash / "expr.yaml") as f: yaml_dict = yaml.safe_load(f) diff --git a/python/xorq/ibis_yaml/tests/test_config.py b/python/xorq/ibis_yaml/tests/test_config.py new file mode 100644 index 00000000..3f7193e1 --- /dev/null +++ b/python/xorq/ibis_yaml/tests/test_config.py @@ -0,0 +1,23 @@ +import pathlib + +from xorq.ibis_yaml.config import ( + BuildConfig, # Replace 'your_module' with the appropriate module name +) + + +def test_default_hash_length(): + config = BuildConfig() + assert config.hash_length == 12, "Default hash_length should be 12" + + +def test_default_build_path(): + config = BuildConfig() + expected_path = pathlib.Path.cwd() / "builds" + assert config.build_path == expected_path + + +def test_set_build_path(): + config = BuildConfig() + new_path = pathlib.Path.cwd() / "custom_builds" + config.build_path = new_path + assert config.build_path == new_path diff --git a/python/xorq/ibis_yaml/translate.py b/python/xorq/ibis_yaml/translate.py index 721eff61..a6cbe06f 100644 --- a/python/xorq/ibis_yaml/translate.py +++ b/python/xorq/ibis_yaml/translate.py @@ -274,8 +274,6 @@ def _unbound_table_to_yaml(op: ops.UnboundTable, compiler: Any) -> dict: @register_from_yaml_handler("UnboundTable") def _unbound_table_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: table_name = yaml_dict["name"] - if not hasattr(compiler, "definitions"): - raise ValueError("Compiler missing definitions with schemas") schema_ref = yaml_dict["schema_ref"] try: @@ -340,9 +338,6 @@ def _cached_node_to_yaml(op: CachedNode, compiler: any) -> dict: @register_from_yaml_handler("CachedNode") def _cached_node_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: - if not hasattr(compiler, "definitions"): - raise ValueError("Compiler missing definitions with schemas") - schema_ref = yaml_dict["schema_ref"] try: schema_def = 
compiler.definitions["schemas"][schema_ref] @@ -428,9 +423,6 @@ def _read_to_yaml(op: Read, compiler: Any) -> dict: @register_from_yaml_handler("Read") def _read_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - if not hasattr(compiler, "definitions"): - raise ValueError("Compiler missing definitions with schemas") - schema_ref = yaml_dict["schema_ref"] schema_def = compiler.definitions["schemas"][schema_ref] @@ -1293,10 +1285,6 @@ def _string_sql_like_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: def _type_from_yaml(yaml_dict: dict) -> dt.DataType: - if isinstance(yaml_dict, str): - raise ValueError( - f"Unexpected string value '{yaml_dict}' - type definitions should be dictionaries" - ) type_name = yaml_dict["name"] base_type = REVERSE_TYPE_REGISTRY.get(type_name) if base_type is None: From a66fb59d170c0f17a393b2a51b1fd6fd536463cb Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Fri, 21 Feb 2025 19:03:59 -0500 Subject: [PATCH 30/45] fix: handle Namespace in tables --- python/xorq/ibis_yaml/tests/test_basic.py | 7 ++++ python/xorq/ibis_yaml/translate.py | 41 +++++++++++++++++++---- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/python/xorq/ibis_yaml/tests/test_basic.py b/python/xorq/ibis_yaml/tests/test_basic.py index 263d4cd2..3c87817d 100644 --- a/python/xorq/ibis_yaml/tests/test_basic.py +++ b/python/xorq/ibis_yaml/tests/test_basic.py @@ -165,3 +165,10 @@ def test_aggregation_roundtrip(t, compiler): yaml_dict = compiler.to_yaml(expr) roundtrip_expr = compiler.from_yaml(yaml_dict) assert roundtrip_expr.schema() == expr.schema() + + +def test_table_perserves_namespace(compiler): + expr = ibis.table({"b": "int64"}, name="t2", database="db", catalog="catalog") + yaml_dict = compiler.to_yaml(expr) + roundtrip_expr = compiler.from_yaml(yaml_dict) + assert roundtrip_expr.op().namespace == expr.op().namespace diff --git a/python/xorq/ibis_yaml/translate.py b/python/xorq/ibis_yaml/translate.py index a6cbe06f..0c754cf5 100644 --- 
a/python/xorq/ibis_yaml/translate.py +++ b/python/xorq/ibis_yaml/translate.py @@ -20,6 +20,7 @@ translate_storage, ) from xorq.vendor.ibis.common.annotations import Argument +from xorq.vendor.ibis.expr.operations.relations import Namespace FROM_YAML_HANDLERS: dict[str, Any] = {} @@ -262,11 +263,18 @@ def _base_op_to_yaml(op: ops.Node, compiler: Any) -> dict: @translate_to_yaml.register(ops.UnboundTable) def _unbound_table_to_yaml(op: ops.UnboundTable, compiler: Any) -> dict: schema_id = compiler.schema_registry.register_schema(op.schema) + namespace_dict = freeze( + { + "catalog": op.namespace.catalog, + "database": op.namespace.database, + } + ) return freeze( { "op": "UnboundTable", "name": op.name, "schema_ref": schema_id, + "namespace": namespace_dict, } ) @@ -280,17 +288,26 @@ def _unbound_table_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: schema_def = compiler.definitions["schemas"][schema_ref] except KeyError: raise ValueError(f"Schema {schema_ref} not found in definitions") - + namespace_dict = yaml_dict.get("namespace", {}) + catalog = namespace_dict.get("catalog") + database = namespace_dict.get("database") schema = { name: _type_from_yaml(dtype_yaml) for name, dtype_yaml in schema_def.items() } - return ibis.table(schema, name=table_name) + # TODO: use UnboundTable node to construct instead of builder API + return ibis.table(schema, name=table_name, catalog=catalog, database=database) @translate_to_yaml.register(ops.DatabaseTable) def _database_table_to_yaml(op: ops.DatabaseTable, compiler: Any) -> dict: profile_name = op.source._profile.hash_name schema_id = compiler.schema_registry.register_schema(op.schema) + namespace_dict = freeze( + { + "catalog": op.namespace.catalog, + "database": op.namespace.database, + } + ) return freeze( { @@ -298,6 +315,7 @@ def _database_table_to_yaml(op: ops.DatabaseTable, compiler: Any) -> dict: "table": op.name, "schema_ref": schema_id, "profile": profile_name, + "namespace": namespace_dict, } ) @@ -306,17 
+324,28 @@ def _database_table_to_yaml(op: ops.DatabaseTable, compiler: Any) -> dict: def database_table_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: profile_name = yaml_dict.get("profile") table_name = yaml_dict.get("table") + namespace_dict = yaml_dict.get("namespace", {}) + catalog = namespace_dict.get("catalog") + database = namespace_dict.get("database") # we should validate that schema is the same schema_ref = yaml_dict.get("schema_ref") - - if not all([profile_name, table_name, schema_ref]): - raise ValueError("Missing required information in YAML for DatabaseTable.") + schema_def = compiler.definitions["schemas"][schema_ref] + fields = [] + for name, dtype_yaml in schema_def.items(): + dtype = _type_from_yaml(dtype_yaml) + fields.append((name, dtype)) + schema = ibis.Schema.from_tuples(fields) try: con = compiler.profiles[profile_name] except KeyError: raise ValueError(f"Profile {profile_name!r} not found in compiler.profiles") - return con.table(table_name) + return ops.DatabaseTable( + schema=schema, + source=con, + name=table_name, + namespace=Namespace(catalog=catalog, database=database), + ).to_expr() @translate_to_yaml.register(CachedNode) From bdb4067f6ae23c9fd288b955a903047b36f44c14 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Fri, 21 Feb 2025 19:26:54 -0500 Subject: [PATCH 31/45] fix: remove checks if values present --- python/xorq/ibis_yaml/translate.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/python/xorq/ibis_yaml/translate.py b/python/xorq/ibis_yaml/translate.py index 0c754cf5..ad6d9dee 100644 --- a/python/xorq/ibis_yaml/translate.py +++ b/python/xorq/ibis_yaml/translate.py @@ -446,6 +446,7 @@ def _read_to_yaml(op: Read, compiler: Any) -> dict: "schema_ref": schema_id, "profile": profile_hash_name, "read_kwargs": freeze(op.read_kwargs if op.read_kwargs else {}), + "values": freeze(op.values), } ) @@ -697,10 +698,9 @@ def _join_to_yaml(op: ops.JoinChain, compiler: Any) -> dict: for 
link in op.rest ], } - if hasattr(op, "values") and op.values: - result["values"] = { - name: translate_to_yaml(val, compiler) for name, val in op.values.items() - } + result["values"] = { + name: translate_to_yaml(val, compiler) for name, val in op.values.items() + } return freeze(result) @@ -716,12 +716,11 @@ def _join_chain_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: ] result = result.join(table, predicates, how=join["how"]) - if "values" in yaml_dict: - values = { - name: translate_from_yaml(val, compiler) - for name, val in yaml_dict["values"].items() - } - result = result.select(values) + values = { + name: translate_from_yaml(val, compiler) + for name, val in yaml_dict["values"].items() + } + result = result.select(values) return result From 981eadde64a93c94ff8970f0b139828be6839ce3 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sat, 22 Feb 2025 08:16:20 -0500 Subject: [PATCH 32/45] fix: kwargs_tuple in profile should be a key,value pair in yaml --- python/xorq/ibis_yaml/compiler.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/xorq/ibis_yaml/compiler.py b/python/xorq/ibis_yaml/compiler.py index 5a88dcd3..37eaf3e8 100644 --- a/python/xorq/ibis_yaml/compiler.py +++ b/python/xorq/ibis_yaml/compiler.py @@ -199,7 +199,10 @@ def compile_expr(self, expr: ir.Expr) -> str: backends = find_all_sources(expr) profiles = { - backend._profile.hash_name: backend._profile.as_dict() + backend._profile.hash_name: { + **backend._profile.as_dict(), + "kwargs_tuple": dict(backend._profile.as_dict()["kwargs_tuple"]), + } for backend in backends } @@ -225,7 +228,10 @@ def load_expr(self, expr_hash: str) -> ir.Expr: def f(values): dct = dict(values) - dct["kwargs_tuple"] = tuple(map(tuple, dct["kwargs_tuple"])) + if isinstance(dct["kwargs_tuple"], dict): + dct["kwargs_tuple"] = tuple(dct["kwargs_tuple"].items()) + else: + dct["kwargs_tuple"] = tuple(map(tuple, dct["kwargs_tuple"])) return dct profiles = { From 
44ec5ed966a17cbbf74136854ff746011fd1f616 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sat, 22 Feb 2025 09:20:46 -0500 Subject: [PATCH 33/45] refactor: pull out udf.py and make common module --- python/xorq/ibis_yaml/common.py | 64 +++++++++++++ python/xorq/ibis_yaml/compiler.py | 7 +- python/xorq/ibis_yaml/translate.py | 139 ++--------------------------- python/xorq/ibis_yaml/udf.py | 97 ++++++++++++++++++++ python/xorq/ibis_yaml/utils.py | 14 --- 5 files changed, 170 insertions(+), 151 deletions(-) create mode 100644 python/xorq/ibis_yaml/common.py create mode 100644 python/xorq/ibis_yaml/udf.py diff --git a/python/xorq/ibis_yaml/common.py b/python/xorq/ibis_yaml/common.py new file mode 100644 index 00000000..c566c662 --- /dev/null +++ b/python/xorq/ibis_yaml/common.py @@ -0,0 +1,64 @@ +import functools +from typing import Any + +import xorq.vendor.ibis.expr.datatypes as dt +import xorq.vendor.ibis.expr.types as ir +from xorq.ibis_yaml.utils import freeze + + +FROM_YAML_HANDLERS: dict[str, Any] = {} + + +def register_from_yaml_handler(*op_names: str): + def decorator(func): + for name in op_names: + FROM_YAML_HANDLERS[name] = func + return func + + return decorator + + +@functools.cache +@functools.singledispatch +def translate_from_yaml(yaml_dict: dict, compiler: Any) -> Any: + op_type = yaml_dict["op"] + if op_type not in FROM_YAML_HANDLERS: + raise NotImplementedError(f"No handler for operation {op_type}") + return FROM_YAML_HANDLERS[op_type](yaml_dict, compiler) + + +@functools.cache +@functools.singledispatch +def translate_to_yaml(op: Any, compiler: Any) -> dict: + raise NotImplementedError(f"No translation rule for {type(op)}") + + +@functools.singledispatch +def _translate_type(dtype: dt.DataType) -> dict: + return freeze({"name": type(dtype).__name__, "nullable": dtype.nullable}) + + +class SchemaRegistry: + def __init__(self): + self.schemas = {} + self.counter = 0 + + def register_schema(self, schema): + frozen_schema = freeze( + {name: 
_translate_type(dtype) for name, dtype in schema.items()} + ) + + for schema_id, existing_schema in self.schemas.items(): + if existing_schema == frozen_schema: + return schema_id + + schema_id = f"schema_{self.counter}" + self.schemas[schema_id] = frozen_schema + self.counter += 1 + return schema_id + + def _register_expr_schema(self, expr: ir.Expr) -> str: + if hasattr(expr, "schema"): + schema = expr.schema() + return self.register_schema(schema) + return None diff --git a/python/xorq/ibis_yaml/compiler.py b/python/xorq/ibis_yaml/compiler.py index 37eaf3e8..e4192ea4 100644 --- a/python/xorq/ibis_yaml/compiler.py +++ b/python/xorq/ibis_yaml/compiler.py @@ -8,10 +8,10 @@ import xorq.vendor.ibis.expr.types as ir from xorq.common.utils.graph_utils import find_all_sources +from xorq.ibis_yaml.common import SchemaRegistry from xorq.ibis_yaml.config import config from xorq.ibis_yaml.sql import generate_sql_plans from xorq.ibis_yaml.translate import ( - SchemaRegistry, translate_from_yaml, translate_to_yaml, ) @@ -20,7 +20,6 @@ from xorq.vendor.ibis.common.collections import FrozenOrderedDict -# is this the right way to handle this? 
or the right place class CleanDictYAMLDumper(yaml.SafeDumper): def represent_frozenordereddict(self, data): return self.represent_dict(dict(data)) @@ -89,7 +88,7 @@ def exists(self, *path_parts) -> bool: def get_expr_hash(self, expr) -> str: expr_hash = dask.base.tokenize(expr) hash_length = config.hash_length - return expr_hash[:hash_length] # TODO: make length of hash as a config + return expr_hash[:hash_length] def save_yaml(self, yaml_dict: Dict[str, Any], expr_hash, filename) -> pathlib.Path: return self.write_yaml(yaml_dict, expr_hash, filename) @@ -101,7 +100,7 @@ def get_build_path(self, expr_hash: str) -> pathlib.Path: return self.ensure_dir(expr_hash) -@attr.s(frozen=True, auto_attribs=True) +@attr.s(frozen=True) class TranslationContext: schema_registry: SchemaRegistry = attr.ib(factory=SchemaRegistry) profiles: FrozenOrderedDict = attr.ib(factory=FrozenOrderedDict) diff --git a/python/xorq/ibis_yaml/translate.py b/python/xorq/ibis_yaml/translate.py index ad6d9dee..e15b7d38 100644 --- a/python/xorq/ibis_yaml/translate.py +++ b/python/xorq/ibis_yaml/translate.py @@ -9,78 +9,22 @@ import xorq.vendor.ibis.expr.datatypes as dt import xorq.vendor.ibis.expr.operations as ops import xorq.vendor.ibis.expr.operations.temporal as tm -import xorq.vendor.ibis.expr.rules as rlz import xorq.vendor.ibis.expr.types as ir from xorq.expr.relations import CachedNode, Read, RemoteTable, into_backend +from xorq.ibis_yaml.common import ( + _translate_type, + register_from_yaml_handler, + translate_from_yaml, + translate_to_yaml, +) from xorq.ibis_yaml.utils import ( - deserialize_udf_function, freeze, load_storage_from_yaml, - serialize_udf_function, translate_storage, ) -from xorq.vendor.ibis.common.annotations import Argument from xorq.vendor.ibis.expr.operations.relations import Namespace -FROM_YAML_HANDLERS: dict[str, Any] = {} - - -class SchemaRegistry: - def __init__(self): - self.schemas = {} - self.counter = 0 - - def register_schema(self, schema): - frozen_schema = 
freeze( - {name: _translate_type(dtype) for name, dtype in schema.items()} - ) - - for schema_id, existing_schema in self.schemas.items(): - if existing_schema == frozen_schema: - return schema_id - - schema_id = f"schema_{self.counter}" - self.schemas[schema_id] = frozen_schema - self.counter += 1 - return schema_id - - def _register_expr_schema(self, expr: ir.Expr) -> str: - if hasattr(expr, "schema"): - schema = expr.schema() - return self.register_schema(schema) - return None - - -def register_from_yaml_handler(*op_names: str): - def decorator(func): - for name in op_names: - FROM_YAML_HANDLERS[name] = func - return func - - return decorator - - -@functools.cache -@functools.singledispatch -def translate_from_yaml(yaml_dict: dict, compiler: Any) -> Any: - op_type = yaml_dict["op"] - if op_type not in FROM_YAML_HANDLERS: - raise NotImplementedError(f"No handler for operation {op_type}") - return FROM_YAML_HANDLERS[op_type](yaml_dict, compiler) - - -@functools.cache -@functools.singledispatch -def translate_to_yaml(op: Any, compiler: Any) -> dict: - raise NotImplementedError(f"No translation rule for {type(op)}") - - -@functools.singledispatch -def _translate_type(dtype: dt.DataType) -> dict: - return freeze({"name": type(dtype).__name__, "nullable": dtype.nullable}) - - @_translate_type.register(dt.Timestamp) def _translate_timestamp_type(dtype: dt.Timestamp) -> dict: base = {"name": "Timestamp", "nullable": dtype.nullable} @@ -1025,77 +969,6 @@ def _searched_case_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: return op.to_expr() -@translate_to_yaml.register(ops.ScalarUDF) -def _scalar_udf_to_yaml(op: ops.ScalarUDF, compiler: Any) -> dict: - print(dir(op)) - if getattr(op.__class__, "__input_type__", None) != ops.udf.InputType.BUILTIN: - raise NotImplementedError( - f"Translation of UDFs with input type {getattr(op.__class__, '__input_type__', None)} is not supported" - ) - arg_names = [ - name - for name in dir(op) - if not name.startswith("__") and name 
not in op.__class__.__slots__ - ] - - return freeze( - { - "op": "ScalarUDF", - "unique_name": op.__func_name__, - "input_type": "builtin", - "args": [translate_to_yaml(arg, compiler) for arg in op.args], - "type": _translate_type(op.dtype), - "pickle": serialize_udf_function(op.__func__), - "module": op.__module__, - "class_name": op.__class__.__name__, - "arg_names": arg_names, - } - ) - - -@register_from_yaml_handler("ScalarUDF") -def _scalar_udf_from_yaml(yaml_dict: dict, compiler: any) -> any: - encoded_fn = yaml_dict.get("pickle") - if not encoded_fn: - raise ValueError("Missing pickle data for ScalarUDF") - fn = deserialize_udf_function(encoded_fn) - - args = tuple( - translate_from_yaml(arg, compiler) for arg in yaml_dict.get("args", []) - ) - if not args: - raise ValueError("ScalarUDF requires at least one argument") - - arg_names = yaml_dict.get("arg_names", [f"arg{i}" for i in range(len(args))]) - - fields = { - name: Argument(pattern=rlz.ValueOf(arg.type()), typehint=arg.type()) - for name, arg in zip(arg_names, args) - } - - bases = (ops.ScalarUDF,) - meta = { - "dtype": dt.dtype(yaml_dict["type"]["name"]), - "__input_type__": ops.udf.InputType.BUILTIN, - "__func__": property(fget=lambda _, f=fn: f), - "__config__": {"volatility": "immutable"}, - "__udf_namespace__": None, - "__module__": yaml_dict.get("module", "__main__"), - "__func_name__": yaml_dict["unique_name"], - } - - kwds = {**fields, **meta} - class_name = yaml_dict.get("class_name", yaml_dict["unique_name"]) - - node = type( - class_name, - bases, - kwds, - ) - - return node(*args).to_expr() - - @register_from_yaml_handler("View") def _view_from_yaml(yaml_dict: dict, compiler: any) -> ir.Expr: underlying = translate_from_yaml(yaml_dict["args"][0], compiler) diff --git a/python/xorq/ibis_yaml/udf.py b/python/xorq/ibis_yaml/udf.py new file mode 100644 index 00000000..508f2500 --- /dev/null +++ b/python/xorq/ibis_yaml/udf.py @@ -0,0 +1,97 @@ +import base64 +from typing import Any + +import 
cloudpickle + +import xorq.vendor.ibis.expr.datatypes as dt +import xorq.vendor.ibis.expr.operations as ops +import xorq.vendor.ibis.expr.rules as rlz +from xorq.ibis_yaml.common import ( + _translate_type, + register_from_yaml_handler, + translate_from_yaml, + translate_to_yaml, +) +from xorq.ibis_yaml.utils import freeze +from xorq.vendor.ibis.common.annotations import Argument + + +def serialize_udf_function(fn: callable) -> str: + pickled = cloudpickle.dumps(fn) + encoded = base64.b64encode(pickled).decode("ascii") + return encoded + + +def deserialize_udf_function(encoded_fn: str) -> callable: + pickled = base64.b64decode(encoded_fn) + return cloudpickle.loads(pickled) + + +@translate_to_yaml.register(ops.ScalarUDF) +def _scalar_udf_to_yaml(op: ops.ScalarUDF, compiler: Any) -> dict: + if getattr(op.__class__, "__input_type__", None) != ops.udf.InputType.BUILTIN: + raise NotImplementedError( + f"Translation of UDFs with input type {getattr(op.__class__, '__input_type__', None)} is not supported" + ) + arg_names = [ + name + for name in dir(op) + if not name.startswith("__") and name not in op.__class__.__slots__ + ] + + return freeze( + { + "op": "ScalarUDF", + "unique_name": op.__func_name__, + "input_type": "builtin", + "args": [translate_to_yaml(arg, compiler) for arg in op.args], + "type": _translate_type(op.dtype), + "pickle": serialize_udf_function(op.__func__), + "module": op.__module__, + "class_name": op.__class__.__name__, + "arg_names": arg_names, + } + ) + + +@register_from_yaml_handler("ScalarUDF") +def _scalar_udf_from_yaml(yaml_dict: dict, compiler: any) -> any: + encoded_fn = yaml_dict.get("pickle") + if not encoded_fn: + raise ValueError("Missing pickle data for ScalarUDF") + fn = deserialize_udf_function(encoded_fn) + + args = tuple( + translate_from_yaml(arg, compiler) for arg in yaml_dict.get("args", []) + ) + if not args: + raise ValueError("ScalarUDF requires at least one argument") + + arg_names = yaml_dict.get("arg_names", [f"arg{i}" for 
i in range(len(args))]) + + fields = { + name: Argument(pattern=rlz.ValueOf(arg.type()), typehint=arg.type()) + for name, arg in zip(arg_names, args) + } + + bases = (ops.ScalarUDF,) + meta = { + "dtype": dt.dtype(yaml_dict["type"]["name"]), + "__input_type__": ops.udf.InputType.BUILTIN, + "__func__": property(fget=lambda _, f=fn: f), + "__config__": {"volatility": "immutable"}, + "__udf_namespace__": None, + "__module__": yaml_dict.get("module", "__main__"), + "__func_name__": yaml_dict["unique_name"], + } + + kwds = {**fields, **meta} + class_name = yaml_dict.get("class_name", yaml_dict["unique_name"]) + + node = type( + class_name, + bases, + kwds, + ) + + return node(*args).to_expr() diff --git a/python/xorq/ibis_yaml/utils.py b/python/xorq/ibis_yaml/utils.py index 6cb02125..52359554 100644 --- a/python/xorq/ibis_yaml/utils.py +++ b/python/xorq/ibis_yaml/utils.py @@ -1,24 +1,10 @@ -import base64 from collections.abc import Mapping, Sequence from typing import Any, Dict -import cloudpickle - from xorq.common.caching import SourceStorage from xorq.vendor.ibis.common.collections import FrozenOrderedDict -def serialize_udf_function(fn: callable) -> str: - pickled = cloudpickle.dumps(fn) - encoded = base64.b64encode(pickled).decode("ascii") - return encoded - - -def deserialize_udf_function(encoded_fn: str) -> callable: - pickled = base64.b64decode(encoded_fn) - return cloudpickle.loads(pickled) - - def freeze(obj): if isinstance(obj, dict): return FrozenOrderedDict({k: freeze(v) for k, v in obj.items()}) From e8cab225929346e2b108a37b4251ac265421b8a6 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sat, 22 Feb 2025 10:13:56 -0500 Subject: [PATCH 34/45] feat: add metadata.json --- python/xorq/ibis_yaml/compiler.py | 26 ++++++++++++++++++++ python/xorq/ibis_yaml/tests/test_compiler.py | 3 +++ python/xorq/ibis_yaml/translate.py | 3 +++ 3 files changed, 32 insertions(+) diff --git a/python/xorq/ibis_yaml/compiler.py b/python/xorq/ibis_yaml/compiler.py index 
e4192ea4..1fe951a2 100644 --- a/python/xorq/ibis_yaml/compiler.py +++ b/python/xorq/ibis_yaml/compiler.py @@ -1,3 +1,4 @@ +import json import pathlib from pathlib import Path from typing import Any, Dict @@ -6,6 +7,8 @@ import dask import yaml +import xorq as xo +import xorq.common.utils.logging_utils as lu import xorq.vendor.ibis.expr.types as ir from xorq.common.utils.graph_utils import find_all_sources from xorq.ibis_yaml.common import SchemaRegistry @@ -68,6 +71,13 @@ def read_yaml(self, *path_parts) -> Dict[str, Any]: with path.open("r") as f: return yaml.safe_load(f) + def read_json(self, *path_parts) -> Dict[str, Any]: + path = self.get_path(*path_parts) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + with path.open("r") as f: + return json.load(f) + def write_text(self, content: str, *path_parts) -> pathlib.Path: path = self.get_path(*path_parts) path.parent.mkdir(parents=True, exist_ok=True) @@ -178,6 +188,19 @@ def _process_sql_plans( return updated_plans + def _make_metadata(self) -> str: + metadata = { + "current_library_version": xo.__version__, + "metadata_version": "0.0.0", # TODO: make it a real thing + } + if lu._git_is_present(): + git_state = lu.get_git_state(hash_diffs=False) + metadata["git_state"] = git_state + + metadata_json = json.dumps(metadata, indent=2) + + return metadata_json + def _process_deferred_reads( self, deferred_reads: Dict[str, Any], expr_hash: str ) -> Dict[str, Any]: @@ -220,6 +243,9 @@ def compile_expr(self, expr: ir.Expr) -> str: updated_deferred_reads, expr_hash, "deferred_reads.yaml" ) + metadata_json = self._make_metadata() + self.artifact_store.write_text(metadata_json, expr_hash, "metadata.json") + return expr_hash def load_expr(self, expr_hash: str) -> ir.Expr: diff --git a/python/xorq/ibis_yaml/tests/test_compiler.py b/python/xorq/ibis_yaml/tests/test_compiler.py index fe81b212..410a92cf 100644 --- a/python/xorq/ibis_yaml/tests/test_compiler.py +++ 
b/python/xorq/ibis_yaml/tests/test_compiler.py @@ -101,7 +101,10 @@ def test_compiler_sql(build_dir): _roundtrip_expr = compiler.load_expr(expr_hash) assert os.path.exists(build_dir / expr_hash / "sql.yaml") + assert os.path.exists(build_dir / expr_hash / "metadata.json") + metadata = compiler.artifact_store.read_json(build_dir, expr_hash, "metadata.json") + assert "current_library_version" in metadata sql_text = pathlib.Path(build_dir / expr_hash / "sql.yaml").read_text() expected_result = ( "queries:\n" diff --git a/python/xorq/ibis_yaml/translate.py b/python/xorq/ibis_yaml/translate.py index e15b7d38..3891cc55 100644 --- a/python/xorq/ibis_yaml/translate.py +++ b/python/xorq/ibis_yaml/translate.py @@ -17,6 +17,9 @@ translate_from_yaml, translate_to_yaml, ) + +# ruff: noqa: F401 +from xorq.ibis_yaml.udf import _scalar_udf_from_yaml, _scalar_udf_to_yaml from xorq.ibis_yaml.utils import ( freeze, load_storage_from_yaml, From f218d8b90e99bd38968816679c7d456d758d545f Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sat, 22 Feb 2025 11:51:00 -0500 Subject: [PATCH 35/45] refactor: TranslationContext in translate --- python/xorq/ibis_yaml/common.py | 71 +++-- python/xorq/ibis_yaml/compiler.py | 13 +- python/xorq/ibis_yaml/translate.py | 496 +++++++++++++++-------------- 3 files changed, 297 insertions(+), 283 deletions(-) diff --git a/python/xorq/ibis_yaml/common.py b/python/xorq/ibis_yaml/common.py index c566c662..374426b6 100644 --- a/python/xorq/ibis_yaml/common.py +++ b/python/xorq/ibis_yaml/common.py @@ -1,43 +1,17 @@ import functools from typing import Any +import attr + import xorq.vendor.ibis.expr.datatypes as dt import xorq.vendor.ibis.expr.types as ir from xorq.ibis_yaml.utils import freeze +from xorq.vendor.ibis.common.collections import FrozenOrderedDict FROM_YAML_HANDLERS: dict[str, Any] = {} -def register_from_yaml_handler(*op_names: str): - def decorator(func): - for name in op_names: - FROM_YAML_HANDLERS[name] = func - return func - - return decorator 
- - -@functools.cache -@functools.singledispatch -def translate_from_yaml(yaml_dict: dict, compiler: Any) -> Any: - op_type = yaml_dict["op"] - if op_type not in FROM_YAML_HANDLERS: - raise NotImplementedError(f"No handler for operation {op_type}") - return FROM_YAML_HANDLERS[op_type](yaml_dict, compiler) - - -@functools.cache -@functools.singledispatch -def translate_to_yaml(op: Any, compiler: Any) -> dict: - raise NotImplementedError(f"No translation rule for {type(op)}") - - -@functools.singledispatch -def _translate_type(dtype: dt.DataType) -> dict: - return freeze({"name": type(dtype).__name__, "nullable": dtype.nullable}) - - class SchemaRegistry: def __init__(self): self.schemas = {} @@ -62,3 +36,42 @@ def _register_expr_schema(self, expr: ir.Expr) -> str: schema = expr.schema() return self.register_schema(schema) return None + + +@attr.s(frozen=True) +class TranslationContext: + schema_registry: SchemaRegistry = attr.ib(factory=SchemaRegistry) + profiles: FrozenOrderedDict = attr.ib(factory=FrozenOrderedDict) + definitions: FrozenOrderedDict = attr.ib(factory=lambda: freeze({"schemas": {}})) + + def update_definitions(self, new_definitions: FrozenOrderedDict): + return attr.evolve(self, definitions=new_definitions) + + +def register_from_yaml_handler(*op_names: str): + def decorator(func): + for name in op_names: + FROM_YAML_HANDLERS[name] = func + return func + + return decorator + + +@functools.cache +@functools.singledispatch +def translate_from_yaml(yaml_dict: dict, context: TranslationContext) -> Any: + op_type = yaml_dict["op"] + if op_type not in FROM_YAML_HANDLERS: + raise NotImplementedError(f"No handler for operation {op_type}") + return FROM_YAML_HANDLERS[op_type](yaml_dict, context) + + +@functools.cache +@functools.singledispatch +def translate_to_yaml(op: Any, context: TranslationContext) -> dict: + raise NotImplementedError(f"No translation rule for {type(op)}") + + +@functools.singledispatch +def _translate_type(dtype: dt.DataType) -> dict: 
+ return freeze({"name": type(dtype).__name__, "nullable": dtype.nullable}) diff --git a/python/xorq/ibis_yaml/compiler.py b/python/xorq/ibis_yaml/compiler.py index 1fe951a2..904de050 100644 --- a/python/xorq/ibis_yaml/compiler.py +++ b/python/xorq/ibis_yaml/compiler.py @@ -3,7 +3,6 @@ from pathlib import Path from typing import Any, Dict -import attr import dask import yaml @@ -11,7 +10,7 @@ import xorq.common.utils.logging_utils as lu import xorq.vendor.ibis.expr.types as ir from xorq.common.utils.graph_utils import find_all_sources -from xorq.ibis_yaml.common import SchemaRegistry +from xorq.ibis_yaml.common import SchemaRegistry, TranslationContext from xorq.ibis_yaml.config import config from xorq.ibis_yaml.sql import generate_sql_plans from xorq.ibis_yaml.translate import ( @@ -110,16 +109,6 @@ def get_build_path(self, expr_hash: str) -> pathlib.Path: return self.ensure_dir(expr_hash) -@attr.s(frozen=True) -class TranslationContext: - schema_registry: SchemaRegistry = attr.ib(factory=SchemaRegistry) - profiles: FrozenOrderedDict = attr.ib(factory=FrozenOrderedDict) - definitions: FrozenOrderedDict = attr.ib(factory=lambda: freeze({"schemas": {}})) - - def update_definitions(self, new_definitions: FrozenOrderedDict): - return attr.evolve(self, definitions=new_definitions) - - class YamlExpressionTranslator: def __init__(self): pass diff --git a/python/xorq/ibis_yaml/translate.py b/python/xorq/ibis_yaml/translate.py index 3891cc55..289eec67 100644 --- a/python/xorq/ibis_yaml/translate.py +++ b/python/xorq/ibis_yaml/translate.py @@ -12,6 +12,7 @@ import xorq.vendor.ibis.expr.types as ir from xorq.expr.relations import CachedNode, Read, RemoteTable, into_backend from xorq.ibis_yaml.common import ( + TranslationContext, _translate_type, register_from_yaml_handler, translate_from_yaml, @@ -127,34 +128,36 @@ def _translate_literal_value(value: Any, dtype: dt.DataType) -> Any: @translate_to_yaml.register(ir.Expr) -def _expr_to_yaml(expr: ir.Expr, compiler: any) -> 
dict: - return translate_to_yaml(expr.op(), compiler) +def _expr_to_yaml(expr: ir.Expr, context: any) -> dict: + return translate_to_yaml(expr.op(), context) @translate_to_yaml.register(ops.WindowFunction) -def _window_function_to_yaml(op: ops.WindowFunction, compiler: Any) -> dict: +def _window_function_to_yaml( + op: ops.WindowFunction, context: TranslationContext +) -> dict: result = { "op": "WindowFunction", - "args": [translate_to_yaml(op.func, compiler)], + "args": [translate_to_yaml(op.func, context)], "type": _translate_type(op.dtype), } if op.group_by: - result["group_by"] = [translate_to_yaml(expr, compiler) for expr in op.group_by] + result["group_by"] = [translate_to_yaml(expr, context) for expr in op.group_by] if op.order_by: - result["order_by"] = [translate_to_yaml(expr, compiler) for expr in op.order_by] + result["order_by"] = [translate_to_yaml(expr, context) for expr in op.order_by] if op.start is not None: result["start"] = ( - translate_to_yaml(op.start.value, compiler)["value"] + translate_to_yaml(op.start.value, context)["value"] if isinstance(op.start, ops.WindowBoundary) else op.start ) if op.end is not None: result["end"] = ( - translate_to_yaml(op.end.value, compiler)["value"] + translate_to_yaml(op.end.value, context)["value"] if isinstance(op.end, ops.WindowBoundary) else op.end ) @@ -163,10 +166,10 @@ def _window_function_to_yaml(op: ops.WindowFunction, compiler: Any) -> dict: @register_from_yaml_handler("WindowFunction") -def _window_function_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - func = translate_from_yaml(yaml_dict["args"][0], compiler) - group_by = [translate_from_yaml(g, compiler) for g in yaml_dict.get("group_by", [])] - order_by = [translate_from_yaml(o, compiler) for o in yaml_dict.get("order_by", [])] +def _window_function_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + func = translate_from_yaml(yaml_dict["args"][0], context) + group_by = [translate_from_yaml(g, context) for g in 
yaml_dict.get("group_by", [])] + order_by = [translate_from_yaml(o, context) for o in yaml_dict.get("order_by", [])] start = ibis.literal(yaml_dict["start"]) if "start" in yaml_dict else None end = ibis.literal(yaml_dict["end"]) if "end" in yaml_dict else None window = ibis.window( @@ -176,11 +179,13 @@ def _window_function_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @translate_to_yaml.register(ops.WindowBoundary) -def _window_boundary_to_yaml(op: ops.WindowBoundary, compiler: Any) -> dict: +def _window_boundary_to_yaml( + op: ops.WindowBoundary, context: TranslationContext +) -> dict: return freeze( { "op": "WindowBoundary", - "value": translate_to_yaml(op.value, compiler), + "value": translate_to_yaml(op.value, context), "preceding": op.preceding, "type": _translate_type(op.dtype), } @@ -188,18 +193,18 @@ def _window_boundary_to_yaml(op: ops.WindowBoundary, compiler: Any) -> dict: @register_from_yaml_handler("WindowBoundary") -def _window_boundary_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - value = translate_from_yaml(yaml_dict["value"], compiler) +def _window_boundary_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + value = translate_from_yaml(yaml_dict["value"], context) return ops.WindowBoundary(value, preceding=yaml_dict["preceding"]) @translate_to_yaml.register(ops.Node) -def _base_op_to_yaml(op: ops.Node, compiler: Any) -> dict: +def _base_op_to_yaml(op: ops.Node, context: TranslationContext) -> dict: return freeze( { "op": type(op).__name__, "args": [ - translate_to_yaml(arg, compiler) + translate_to_yaml(arg, context) for arg in op.args if isinstance(arg, (ops.Value, ops.Node)) ], @@ -208,8 +213,8 @@ def _base_op_to_yaml(op: ops.Node, compiler: Any) -> dict: @translate_to_yaml.register(ops.UnboundTable) -def _unbound_table_to_yaml(op: ops.UnboundTable, compiler: Any) -> dict: - schema_id = compiler.schema_registry.register_schema(op.schema) +def _unbound_table_to_yaml(op: ops.UnboundTable, context: 
TranslationContext) -> dict: + schema_id = context.schema_registry.register_schema(op.schema) namespace_dict = freeze( { "catalog": op.namespace.catalog, @@ -227,12 +232,12 @@ def _unbound_table_to_yaml(op: ops.UnboundTable, compiler: Any) -> dict: @register_from_yaml_handler("UnboundTable") -def _unbound_table_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: +def _unbound_table_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: table_name = yaml_dict["name"] schema_ref = yaml_dict["schema_ref"] try: - schema_def = compiler.definitions["schemas"][schema_ref] + schema_def = context.definitions["schemas"][schema_ref] except KeyError: raise ValueError(f"Schema {schema_ref} not found in definitions") namespace_dict = yaml_dict.get("namespace", {}) @@ -246,9 +251,9 @@ def _unbound_table_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @translate_to_yaml.register(ops.DatabaseTable) -def _database_table_to_yaml(op: ops.DatabaseTable, compiler: Any) -> dict: +def _database_table_to_yaml(op: ops.DatabaseTable, context: TranslationContext) -> dict: profile_name = op.source._profile.hash_name - schema_id = compiler.schema_registry.register_schema(op.schema) + schema_id = context.schema_registry.register_schema(op.schema) namespace_dict = freeze( { "catalog": op.namespace.catalog, @@ -268,7 +273,7 @@ def _database_table_to_yaml(op: ops.DatabaseTable, compiler: Any) -> dict: @register_from_yaml_handler("DatabaseTable") -def database_table_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: +def database_table_from_yaml(yaml_dict: dict, context: TranslationContext) -> ibis.Expr: profile_name = yaml_dict.get("profile") table_name = yaml_dict.get("table") namespace_dict = yaml_dict.get("namespace", {}) @@ -276,7 +281,7 @@ def database_table_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: database = namespace_dict.get("database") # we should validate that schema is the same schema_ref = yaml_dict.get("schema_ref") - schema_def = 
compiler.definitions["schemas"][schema_ref] + schema_def = context.definitions["schemas"][schema_ref] fields = [] for name, dtype_yaml in schema_def.items(): dtype = _type_from_yaml(dtype_yaml) @@ -284,9 +289,9 @@ def database_table_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: schema = ibis.Schema.from_tuples(fields) try: - con = compiler.profiles[profile_name] + con = context.profiles[profile_name] except KeyError: - raise ValueError(f"Profile {profile_name!r} not found in compiler.profiles") + raise ValueError(f"Profile {profile_name!r} not found in context.profiles") return ops.DatabaseTable( schema=schema, source=con, @@ -296,27 +301,27 @@ def database_table_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.Expr: @translate_to_yaml.register(CachedNode) -def _cached_node_to_yaml(op: CachedNode, compiler: any) -> dict: - schema_id = compiler.schema_registry.register_schema(op.schema) +def _cached_node_to_yaml(op: CachedNode, context: any) -> dict: + schema_id = context.schema_registry.register_schema(op.schema) # source should be called profile_name return freeze( { "op": "CachedNode", "schema_ref": schema_id, - "parent": translate_to_yaml(op.parent, compiler), + "parent": translate_to_yaml(op.parent, context), "source": op.source._profile.hash_name, - "storage": translate_storage(op.storage, compiler), + "storage": translate_storage(op.storage, context), "values": freeze(op.values), } ) @register_from_yaml_handler("CachedNode") -def _cached_node_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: +def _cached_node_from_yaml(yaml_dict: dict, context: any) -> ibis.Expr: schema_ref = yaml_dict["schema_ref"] try: - schema_def = compiler.definitions["schemas"][schema_ref] + schema_def = context.definitions["schemas"][schema_ref] except KeyError: raise ValueError(f"Schema {schema_ref} not found in definitions") @@ -324,13 +329,13 @@ def _cached_node_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: name: _type_from_yaml(dtype_yaml) for name, 
dtype_yaml in schema_def.items() } - parent_expr = translate_from_yaml(yaml_dict["parent"], compiler) + parent_expr = translate_from_yaml(yaml_dict["parent"], context) profile_name = yaml_dict.get("source") try: - source = compiler.profiles[profile_name] + source = context.profiles[profile_name] except KeyError: - raise ValueError(f"Profile {profile_name!r} not found in compiler.profiles") - storage = load_storage_from_yaml(yaml_dict["storage"], compiler) + raise ValueError(f"Profile {profile_name!r} not found in context.profiles") + storage = load_storage_from_yaml(yaml_dict["storage"], context) op = CachedNode( schema=schema, @@ -342,10 +347,10 @@ def _cached_node_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: @translate_to_yaml.register(RemoteTable) -def _remotetable_to_yaml(op: RemoteTable, compiler: any) -> dict: +def _remotetable_to_yaml(op: RemoteTable, context: any) -> dict: profile_name = op.source._profile.hash_name - remote_expr_yaml = translate_to_yaml(op.remote_expr, compiler) - schema_id = compiler.schema_registry.register_schema(op.schema) + remote_expr_yaml = translate_to_yaml(op.remote_expr, context) + schema_id = context.schema_registry.register_schema(op.schema) # TODO: change profile to profile_name return freeze( { @@ -359,7 +364,7 @@ def _remotetable_to_yaml(op: RemoteTable, compiler: any) -> dict: @register_from_yaml_handler("RemoteTable") -def _remotetable_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: +def _remotetable_from_yaml(yaml_dict: dict, context: any) -> ibis.Expr: profile_name = yaml_dict.get("profile") table_name = yaml_dict.get("table") remote_expr_yaml = yaml_dict.get("remote_expr") @@ -368,19 +373,19 @@ def _remotetable_from_yaml(yaml_dict: dict, compiler: any) -> ibis.Expr: "Missing keys in RemoteTable YAML; ensure 'profile_name' are present." 
) try: - con = compiler.profiles[profile_name] + con = context.profiles[profile_name] except KeyError: - raise ValueError(f"Profile {profile_name!r} not found in compiler.profiles") + raise ValueError(f"Profile {profile_name!r} not found in context.profiles") - remote_expr = translate_from_yaml(remote_expr_yaml, compiler) + remote_expr = translate_from_yaml(remote_expr_yaml, context) remote_table_expr = into_backend(remote_expr, con, table_name) return remote_table_expr @translate_to_yaml.register(Read) -def _read_to_yaml(op: Read, compiler: Any) -> dict: - schema_id = compiler.schema_registry.register_schema(op.schema) +def _read_to_yaml(op: Read, context: TranslationContext) -> dict: + schema_id = context.schema_registry.register_schema(op.schema) profile_hash_name = ( op.source._profile.hash_name if hasattr(op.source, "_profile") else None ) @@ -399,9 +404,9 @@ def _read_to_yaml(op: Read, compiler: Any) -> dict: @register_from_yaml_handler("Read") -def _read_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: +def _read_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: schema_ref = yaml_dict["schema_ref"] - schema_def = compiler.definitions["schemas"][schema_ref] + schema_def = context.definitions["schemas"][schema_ref] schema = { name: _type_from_yaml(dtype_yaml) for name, dtype_yaml in schema_def.items() @@ -409,7 +414,7 @@ def _read_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: profile_hash_name = yaml_dict.get("profile") - source = compiler.profiles[profile_hash_name] + source = context.profiles[profile_hash_name] read_op = Read( method_name=yaml_dict["method_name"], @@ -423,26 +428,26 @@ def _read_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @translate_to_yaml.register(ops.Literal) -def _literal_to_yaml(op: ops.Literal, compiler: Any) -> dict: +def _literal_to_yaml(op: ops.Literal, context: TranslationContext) -> dict: value = _translate_literal_value(op.value, op.dtype) return freeze({"op": "Literal", "value": value, 
"type": _translate_type(op.dtype)}) @register_from_yaml_handler("Literal") -def _literal_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: +def _literal_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: value = yaml_dict["value"] dtype = _type_from_yaml(yaml_dict["type"]) return ibis.literal(value, type=dtype) @translate_to_yaml.register(ops.ValueOp) -def _value_op_to_yaml(op: ops.ValueOp, compiler: Any) -> dict: +def _value_op_to_yaml(op: ops.ValueOp, context: TranslationContext) -> dict: return freeze( { "op": type(op).__name__, "type": _translate_type(op.dtype), "args": [ - translate_to_yaml(arg, compiler) + translate_to_yaml(arg, context) for arg in op.args if isinstance(arg, (ops.Value, ops.Node)) ], @@ -451,90 +456,90 @@ def _value_op_to_yaml(op: ops.ValueOp, compiler: Any) -> dict: @register_from_yaml_handler("ValueOp") -def _value_op_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] +def _value_op_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + args = [translate_from_yaml(arg, context) for arg in yaml_dict["args"]] method_name = yaml_dict["op"].lower() method = getattr(args[0], method_name) return method(*args[1:]) @translate_to_yaml.register(ops.StringUnary) -def _string_unary_to_yaml(op: ops.StringUnary, compiler: Any) -> dict: +def _string_unary_to_yaml(op: ops.StringUnary, context: TranslationContext) -> dict: return freeze( { "op": type(op).__name__, - "args": [translate_to_yaml(op.arg, compiler)], + "args": [translate_to_yaml(op.arg, context)], "type": _translate_type(op.dtype), } ) @register_from_yaml_handler("StringUnary") -def _string_unary_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - arg = translate_from_yaml(yaml_dict["args"][0], compiler) +def _string_unary_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], context) method_name = 
yaml_dict["op"].lower() return getattr(arg, method_name)() @translate_to_yaml.register(ops.Substring) -def _substring_to_yaml(op: ops.Substring, compiler: Any) -> dict: +def _substring_to_yaml(op: ops.Substring, context: TranslationContext) -> dict: args = [ - translate_to_yaml(op.arg, compiler), - translate_to_yaml(op.start, compiler), + translate_to_yaml(op.arg, context), + translate_to_yaml(op.start, context), ] if op.length is not None: - args.append(translate_to_yaml(op.length, compiler)) + args.append(translate_to_yaml(op.length, context)) return freeze({"op": "Substring", "args": args, "type": _translate_type(op.dtype)}) @register_from_yaml_handler("Substring") -def _substring_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] +def _substring_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + args = [translate_from_yaml(arg, context) for arg in yaml_dict["args"]] return args[0].substr(args[1], args[2] if len(args) > 2 else None) @translate_to_yaml.register(ops.StringLength) -def _string_length_to_yaml(op: ops.StringLength, compiler: Any) -> dict: +def _string_length_to_yaml(op: ops.StringLength, context: TranslationContext) -> dict: return freeze( { "op": "StringLength", - "args": [translate_to_yaml(op.arg, compiler)], + "args": [translate_to_yaml(op.arg, context)], "type": _translate_type(op.dtype), } ) @register_from_yaml_handler("StringLength") -def _string_length_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - arg = translate_from_yaml(yaml_dict["args"][0], compiler) +def _string_length_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], context) return arg.length() @translate_to_yaml.register(ops.StringConcat) -def _string_concat_to_yaml(op: ops.StringConcat, compiler: Any) -> dict: +def _string_concat_to_yaml(op: ops.StringConcat, context: TranslationContext) -> dict: return freeze( { 
"op": "StringConcat", - "args": [translate_to_yaml(arg, compiler) for arg in op.arg], + "args": [translate_to_yaml(arg, context) for arg in op.arg], "type": _translate_type(op.dtype), } ) @register_from_yaml_handler("StringConcat") -def _string_concat_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] +def _string_concat_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + args = [translate_from_yaml(arg, context) for arg in yaml_dict["args"]] return functools.reduce(lambda x, y: x.concat(y), args) @translate_to_yaml.register(ops.BinaryOp) -def _binary_op_to_yaml(op: ops.BinaryOp, compiler: Any) -> dict: +def _binary_op_to_yaml(op: ops.BinaryOp, context: TranslationContext) -> dict: return freeze( { "op": type(op).__name__, "args": [ - translate_to_yaml(op.left, compiler), - translate_to_yaml(op.right, compiler), + translate_to_yaml(op.left, context), + translate_to_yaml(op.right, context), ], "type": _translate_type(op.dtype), } @@ -542,52 +547,51 @@ def _binary_op_to_yaml(op: ops.BinaryOp, compiler: Any) -> dict: @register_from_yaml_handler("BinaryOp") -def _binary_op_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] +def _binary_op_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + args = [translate_from_yaml(arg, context) for arg in yaml_dict["args"]] op_name = yaml_dict["op"].lower() return getattr(args[0], op_name)(args[1]) @translate_to_yaml.register(ops.Filter) -def _filter_to_yaml(op: ops.Filter, compiler: Any) -> dict: +def _filter_to_yaml(op: ops.Filter, context: TranslationContext) -> dict: return freeze( { "op": "Filter", - "parent": translate_to_yaml(op.parent, compiler), - "predicates": [translate_to_yaml(pred, compiler) for pred in op.predicates], + "parent": translate_to_yaml(op.parent, context), + "predicates": [translate_to_yaml(pred, context) for pred in 
op.predicates], } ) @register_from_yaml_handler("Filter") -def _filter_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - parent = translate_from_yaml(yaml_dict["parent"], compiler) +def _filter_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + parent = translate_from_yaml(yaml_dict["parent"], context) predicates = [ - translate_from_yaml(pred, compiler) for pred in yaml_dict["predicates"] + translate_from_yaml(pred, context) for pred in yaml_dict["predicates"] ] filter_op = ops.Filter(parent, predicates) return filter_op.to_expr() @translate_to_yaml.register(ops.Project) -def _project_to_yaml(op: ops.Project, compiler: Any) -> dict: +def _project_to_yaml(op: ops.Project, context: TranslationContext) -> dict: return freeze( { "op": "Project", - "parent": translate_to_yaml(op.parent, compiler), + "parent": translate_to_yaml(op.parent, context), "values": { - name: translate_to_yaml(val, compiler) - for name, val in op.values.items() + name: translate_to_yaml(val, context) for name, val in op.values.items() }, } ) @register_from_yaml_handler("Project") -def _project_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - parent = translate_from_yaml(yaml_dict["parent"], compiler) +def _project_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + parent = translate_from_yaml(yaml_dict["parent"], context) values = { - name: translate_from_yaml(val, compiler) + name: translate_from_yaml(val, context) for name, val in yaml_dict["values"].items() } projected = parent.projection(values) @@ -595,14 +599,14 @@ def _project_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @translate_to_yaml.register(ops.Aggregate) -def _aggregate_to_yaml(op: ops.Aggregate, compiler: Any) -> dict: +def _aggregate_to_yaml(op: ops.Aggregate, context: TranslationContext) -> dict: return freeze( { "op": "Aggregate", - "parent": translate_to_yaml(op.parent, compiler), - "by": [translate_to_yaml(group, compiler) for group in op.groups.values()], + 
"parent": translate_to_yaml(op.parent, context), + "by": [translate_to_yaml(group, context) for group in op.groups.values()], "metrics": { - name: translate_to_yaml(metric, compiler) + name: translate_to_yaml(metric, context) for name, metric in op.metrics.items() }, } @@ -610,14 +614,14 @@ def _aggregate_to_yaml(op: ops.Aggregate, compiler: Any) -> dict: @register_from_yaml_handler("Aggregate") -def _aggregate_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - parent = translate_from_yaml(yaml_dict["parent"], compiler) +def _aggregate_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + parent = translate_from_yaml(yaml_dict["parent"], context) groups = tuple( - translate_from_yaml(group, compiler) for group in yaml_dict.get("by", []) + translate_from_yaml(group, context) for group in yaml_dict.get("by", []) ) metrics = { - name: translate_from_yaml(metric, compiler) + name: translate_from_yaml(metric, context) for name, metric in yaml_dict.get("metrics", {}).items() } @@ -630,41 +634,39 @@ def _aggregate_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @translate_to_yaml.register(ops.JoinChain) -def _join_to_yaml(op: ops.JoinChain, compiler: Any) -> dict: +def _join_to_yaml(op: ops.JoinChain, context: TranslationContext) -> dict: result = { "op": "JoinChain", - "first": translate_to_yaml(op.first, compiler), + "first": translate_to_yaml(op.first, context), "rest": [ { "how": link.how, - "table": translate_to_yaml(link.table, compiler), + "table": translate_to_yaml(link.table, context), "predicates": [ - translate_to_yaml(pred, compiler) for pred in link.predicates + translate_to_yaml(pred, context) for pred in link.predicates ], } for link in op.rest ], } result["values"] = { - name: translate_to_yaml(val, compiler) for name, val in op.values.items() + name: translate_to_yaml(val, context) for name, val in op.values.items() } return freeze(result) @register_from_yaml_handler("JoinChain") -def _join_chain_from_yaml(yaml_dict: dict, 
compiler: Any) -> ir.Expr: - first = translate_from_yaml(yaml_dict["first"], compiler) +def _join_chain_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + first = translate_from_yaml(yaml_dict["first"], context) result = first for join in yaml_dict["rest"]: - table = translate_from_yaml(join["table"], compiler) - predicates = [ - translate_from_yaml(pred, compiler) for pred in join["predicates"] - ] + table = translate_from_yaml(join["table"], context) + predicates = [translate_from_yaml(pred, context) for pred in join["predicates"]] result = result.join(table, predicates, how=join["how"]) values = { - name: translate_from_yaml(val, compiler) + name: translate_from_yaml(val, context) for name, val in yaml_dict["values"].items() } result = result.select(values) @@ -672,30 +674,30 @@ def _join_chain_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @translate_to_yaml.register(ops.Sort) -def _sort_to_yaml(op: ops.Sort, compiler: Any) -> dict: +def _sort_to_yaml(op: ops.Sort, context: TranslationContext) -> dict: return freeze( { "op": "Sort", - "parent": translate_to_yaml(op.parent, compiler), - "keys": [translate_to_yaml(key, compiler) for key in op.keys], + "parent": translate_to_yaml(op.parent, context), + "keys": [translate_to_yaml(key, context) for key in op.keys], } ) @register_from_yaml_handler("Sort") -def _sort_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - parent = translate_from_yaml(yaml_dict["parent"], compiler) - keys = tuple(translate_from_yaml(key, compiler) for key in yaml_dict["keys"]) +def _sort_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + parent = translate_from_yaml(yaml_dict["parent"], context) + keys = tuple(translate_from_yaml(key, context) for key in yaml_dict["keys"]) sort_op = ops.Sort(parent, keys=keys) return sort_op.to_expr() @translate_to_yaml.register(ops.SortKey) -def _sort_key_to_yaml(op: ops.SortKey, compiler: Any) -> dict: +def _sort_key_to_yaml(op: ops.SortKey, context: 
TranslationContext) -> dict: return freeze( { "op": "SortKey", - "arg": translate_to_yaml(op.expr, compiler), + "arg": translate_to_yaml(op.expr, context), "ascending": op.ascending, "nulls_first": op.nulls_first, "type": _translate_type(op.dtype), @@ -704,19 +706,19 @@ def _sort_key_to_yaml(op: ops.SortKey, compiler: Any) -> dict: @register_from_yaml_handler("SortKey") -def _sort_key_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - expr = translate_from_yaml(yaml_dict["arg"], compiler) +def _sort_key_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + expr = translate_from_yaml(yaml_dict["arg"], context) ascending = yaml_dict.get("ascending", True) nulls_first = yaml_dict.get("nulls_first", False) return ops.SortKey(expr, ascending=ascending, nulls_first=nulls_first).to_expr() @translate_to_yaml.register(ops.Limit) -def _limit_to_yaml(op: ops.Limit, compiler: Any) -> dict: +def _limit_to_yaml(op: ops.Limit, context: TranslationContext) -> dict: return freeze( { "op": "Limit", - "parent": translate_to_yaml(op.parent, compiler), + "parent": translate_to_yaml(op.parent, context), "n": op.n, "offset": op.offset, } @@ -724,70 +726,74 @@ def _limit_to_yaml(op: ops.Limit, compiler: Any) -> dict: @register_from_yaml_handler("Limit") -def _limit_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - parent = translate_from_yaml(yaml_dict["parent"], compiler) +def _limit_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + parent = translate_from_yaml(yaml_dict["parent"], context) return parent.limit(yaml_dict["n"], offset=yaml_dict["offset"]) @translate_to_yaml.register(ops.ScalarSubquery) -def _scalar_subquery_to_yaml(op: ops.ScalarSubquery, compiler: Any) -> dict: +def _scalar_subquery_to_yaml( + op: ops.ScalarSubquery, context: TranslationContext +) -> dict: return freeze( { "op": "ScalarSubquery", - "args": [translate_to_yaml(arg, compiler) for arg in op.args], + "args": [translate_to_yaml(arg, context) for arg in op.args], 
"type": _translate_type(op.dtype), } ) @register_from_yaml_handler("ScalarSubquery") -def _scalar_subquery_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - subquery = translate_from_yaml(yaml_dict["args"][0], compiler) +def _scalar_subquery_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + subquery = translate_from_yaml(yaml_dict["args"][0], context) return ops.ScalarSubquery(subquery).to_expr() @translate_to_yaml.register(ops.ExistsSubquery) -def _exists_subquery_to_yaml(op: ops.ExistsSubquery, compiler: Any) -> dict: +def _exists_subquery_to_yaml( + op: ops.ExistsSubquery, context: TranslationContext +) -> dict: return freeze( { "op": "ExistsSubquery", - "rel": translate_to_yaml(op.rel, compiler), + "rel": translate_to_yaml(op.rel, context), "type": _translate_type(op.dtype), } ) @register_from_yaml_handler("ExistsSubquery") -def _exists_subquery_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - rel = translate_from_yaml(yaml_dict["rel"], compiler) +def _exists_subquery_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + rel = translate_from_yaml(yaml_dict["rel"], context) return ops.ExistsSubquery(rel).to_expr() @translate_to_yaml.register(ops.InSubquery) -def _in_subquery_to_yaml(op: ops.InSubquery, compiler: Any) -> dict: +def _in_subquery_to_yaml(op: ops.InSubquery, context: TranslationContext) -> dict: return freeze( { "op": "InSubquery", - "needle": translate_to_yaml(op.needle, compiler), - "haystack": translate_to_yaml(op.rel, compiler), + "needle": translate_to_yaml(op.needle, context), + "haystack": translate_to_yaml(op.rel, context), "type": _translate_type(op.dtype), } ) @register_from_yaml_handler("InSubquery") -def _in_subquery_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - needle = translate_from_yaml(yaml_dict["needle"], compiler) - haystack = translate_from_yaml(yaml_dict["haystack"], compiler) +def _in_subquery_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + needle 
= translate_from_yaml(yaml_dict["needle"], context) + haystack = translate_from_yaml(yaml_dict["haystack"], context) return ops.InSubquery(haystack, needle).to_expr() @translate_to_yaml.register(ops.Field) -def _field_to_yaml(op: ops.Field, compiler: Any) -> dict: +def _field_to_yaml(op: ops.Field, context: TranslationContext) -> dict: result = { "op": "Field", "name": op.name, - "relation": translate_to_yaml(op.rel, compiler), + "relation": translate_to_yaml(op.rel, context), "type": _translate_type(op.dtype), } @@ -800,8 +806,8 @@ def _field_to_yaml(op: ops.Field, compiler: Any) -> dict: @register_from_yaml_handler("Field") -def field_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - relation = translate_from_yaml(yaml_dict["relation"], compiler) +def field_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + relation = translate_from_yaml(yaml_dict["relation"], context) target_name = yaml_dict["name"] source_name = yaml_dict.get("original_name", target_name) field = relation[source_name] @@ -812,13 +818,13 @@ def field_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @translate_to_yaml.register(ops.InValues) -def _in_values_to_yaml(op: ops.InValues, compiler: Any) -> dict: +def _in_values_to_yaml(op: ops.InValues, context: TranslationContext) -> dict: return freeze( { "op": "InValues", "args": [ - translate_to_yaml(op.value, compiler), - *[translate_to_yaml(opt, compiler) for opt in op.options], + translate_to_yaml(op.value, context), + *[translate_to_yaml(opt, context) for opt in op.options], ], "type": _translate_type(op.dtype), } @@ -826,71 +832,71 @@ def _in_values_to_yaml(op: ops.InValues, compiler: Any) -> dict: @register_from_yaml_handler("InValues") -def _in_values_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - value = translate_from_yaml(yaml_dict["args"][0], compiler) - options = tuple(translate_from_yaml(opt, compiler) for opt in yaml_dict["args"][1:]) +def _in_values_from_yaml(yaml_dict: dict, context: 
TranslationContext) -> ir.Expr: + value = translate_from_yaml(yaml_dict["args"][0], context) + options = tuple(translate_from_yaml(opt, context) for opt in yaml_dict["args"][1:]) return ops.InValues(value, options).to_expr() @translate_to_yaml.register(ops.SimpleCase) -def _simple_case_to_yaml(op: ops.SimpleCase, compiler: Any) -> dict: +def _simple_case_to_yaml(op: ops.SimpleCase, context: TranslationContext) -> dict: return freeze( { "op": "SimpleCase", - "base": translate_to_yaml(op.base, compiler), - "cases": [translate_to_yaml(case, compiler) for case in op.cases], - "results": [translate_to_yaml(result, compiler) for result in op.results], - "default": translate_to_yaml(op.default, compiler), + "base": translate_to_yaml(op.base, context), + "cases": [translate_to_yaml(case, context) for case in op.cases], + "results": [translate_to_yaml(result, context) for result in op.results], + "default": translate_to_yaml(op.default, context), "type": _translate_type(op.dtype), } ) @register_from_yaml_handler("SimpleCase") -def _simple_case_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - base = translate_from_yaml(yaml_dict["base"], compiler) - cases = tuple(translate_from_yaml(case, compiler) for case in yaml_dict["cases"]) +def _simple_case_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + base = translate_from_yaml(yaml_dict["base"], context) + cases = tuple(translate_from_yaml(case, context) for case in yaml_dict["cases"]) results = tuple( - translate_from_yaml(result, compiler) for result in yaml_dict["results"] + translate_from_yaml(result, context) for result in yaml_dict["results"] ) - default = translate_from_yaml(yaml_dict["default"], compiler) + default = translate_from_yaml(yaml_dict["default"], context) return ops.SimpleCase(base, cases, results, default).to_expr() @translate_to_yaml.register(ops.IfElse) -def _if_else_to_yaml(op: ops.IfElse, compiler: Any) -> dict: +def _if_else_to_yaml(op: ops.IfElse, context: TranslationContext) 
-> dict: return freeze( { "op": "IfElse", - "bool_expr": translate_to_yaml(op.bool_expr, compiler), - "true_expr": translate_to_yaml(op.true_expr, compiler), - "false_null_expr": translate_to_yaml(op.false_null_expr, compiler), + "bool_expr": translate_to_yaml(op.bool_expr, context), + "true_expr": translate_to_yaml(op.true_expr, context), + "false_null_expr": translate_to_yaml(op.false_null_expr, context), "type": _translate_type(op.dtype), } ) @register_from_yaml_handler("IfElse") -def _if_else_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - bool_expr = translate_from_yaml(yaml_dict["bool_expr"], compiler) - true_expr = translate_from_yaml(yaml_dict["true_expr"], compiler) - false_null_expr = translate_from_yaml(yaml_dict["false_null_expr"], compiler) +def _if_else_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + bool_expr = translate_from_yaml(yaml_dict["bool_expr"], context) + true_expr = translate_from_yaml(yaml_dict["true_expr"], context) + false_null_expr = translate_from_yaml(yaml_dict["false_null_expr"], context) return ops.IfElse(bool_expr, true_expr, false_null_expr).to_expr() @translate_to_yaml.register(ops.CountDistinct) -def _count_distinct_to_yaml(op: ops.CountDistinct, compiler: Any) -> dict: +def _count_distinct_to_yaml(op: ops.CountDistinct, context: TranslationContext) -> dict: return freeze( { "op": "CountDistinct", - "args": [translate_to_yaml(op.arg, compiler)], + "args": [translate_to_yaml(op.arg, context)], "type": _translate_type(op.dtype), } ) @translate_to_yaml.register(ops.RankBase) -def _rank_base_to_yaml(op: ops.RankBase, compiler: Any) -> dict: +def _rank_base_to_yaml(op: ops.RankBase, context: TranslationContext) -> dict: return freeze( { "op": type(op).__name__, @@ -899,28 +905,28 @@ def _rank_base_to_yaml(op: ops.RankBase, compiler: Any) -> dict: @register_from_yaml_handler("RowNumber") -def _row_number_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: +def _row_number_from_yaml(yaml_dict: dict, 
context: TranslationContext) -> ir.Expr: return ibis.row_number() @register_from_yaml_handler("CountDistinct") -def _count_distinct_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - arg = translate_from_yaml(yaml_dict["args"][0], compiler) +def _count_distinct_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], context) return arg.nunique() @translate_to_yaml.register(ops.SelfReference) -def _self_reference_to_yaml(op: ops.SelfReference, compiler: Any) -> dict: +def _self_reference_to_yaml(op: ops.SelfReference, context: TranslationContext) -> dict: result = {"op": "SelfReference", "identifier": op.identifier} if op.args: - result["args"] = [translate_to_yaml(op.args[0], compiler)] + result["args"] = [translate_to_yaml(op.args[0], context)] return freeze(result) @register_from_yaml_handler("SelfReference") -def _self_reference_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: +def _self_reference_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: if "args" in yaml_dict and yaml_dict["args"]: - underlying = translate_from_yaml(yaml_dict["args"][0], compiler) + underlying = translate_from_yaml(yaml_dict["args"][0], context) else: if underlying is None: raise NotImplementedError("No relation available for SelfReference") @@ -932,49 +938,49 @@ def _self_reference_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @translate_to_yaml.register(ops.DropColumns) -def _drop_columns_to_yaml(op: ops.DropColumns, compiler: Any) -> dict: +def _drop_columns_to_yaml(op: ops.DropColumns, context: TranslationContext) -> dict: return freeze( { "op": "DropColumns", - "parent": translate_to_yaml(op.parent, compiler), + "parent": translate_to_yaml(op.parent, context), "columns_to_drop": list(op.columns_to_drop), } ) @register_from_yaml_handler("DropColumns") -def _drop_columns_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - parent = translate_from_yaml(yaml_dict["parent"], compiler) 
+def _drop_columns_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + parent = translate_from_yaml(yaml_dict["parent"], context) columns = frozenset(yaml_dict["columns_to_drop"]) op = ops.DropColumns(parent, columns) return op.to_expr() @translate_to_yaml.register(ops.SearchedCase) -def _searched_case_to_yaml(op: ops.SearchedCase, compiler: Any) -> dict: +def _searched_case_to_yaml(op: ops.SearchedCase, context: TranslationContext) -> dict: return freeze( { "op": "SearchedCase", - "cases": [translate_to_yaml(case, compiler) for case in op.cases], - "results": [translate_to_yaml(result, compiler) for result in op.results], - "default": translate_to_yaml(op.default, compiler), + "cases": [translate_to_yaml(case, context) for case in op.cases], + "results": [translate_to_yaml(result, context) for result in op.results], + "default": translate_to_yaml(op.default, context), "dtype": _translate_type(op.dtype), } ) @register_from_yaml_handler("SearchedCase") -def _searched_case_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - cases = [translate_from_yaml(case, compiler) for case in yaml_dict["cases"]] - results = [translate_from_yaml(result, compiler) for result in yaml_dict["results"]] - default = translate_from_yaml(yaml_dict["default"], compiler) +def _searched_case_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + cases = [translate_from_yaml(case, context) for case in yaml_dict["cases"]] + results = [translate_from_yaml(result, context) for result in yaml_dict["results"]] + default = translate_from_yaml(yaml_dict["default"], context) op = ops.SearchedCase(cases, results, default) return op.to_expr() @register_from_yaml_handler("View") -def _view_from_yaml(yaml_dict: dict, compiler: any) -> ir.Expr: - underlying = translate_from_yaml(yaml_dict["args"][0], compiler) +def _view_from_yaml(yaml_dict: dict, context: any) -> ir.Expr: + underlying = translate_from_yaml(yaml_dict["args"][0], context) alias = yaml_dict.get("name") 
if alias: return underlying.alias(alias) @@ -982,15 +988,17 @@ def _view_from_yaml(yaml_dict: dict, compiler: any) -> ir.Expr: @register_from_yaml_handler("Mean") -def _mean_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] +def _mean_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + args = [translate_from_yaml(arg, context) for arg in yaml_dict["args"]] return args[0].mean() @register_from_yaml_handler("Add", "Subtract", "Multiply", "Divide") -def _binary_arithmetic_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - left = translate_from_yaml(yaml_dict["args"][0], compiler) - right = translate_from_yaml(yaml_dict["args"][1], compiler) +def _binary_arithmetic_from_yaml( + yaml_dict: dict, context: TranslationContext +) -> ir.Expr: + left = translate_from_yaml(yaml_dict["args"][0], context) + right = translate_from_yaml(yaml_dict["args"][1], context) op_map = { "Add": lambda left, right: left + right, "Subtract": lambda left, right: left - right, @@ -1004,54 +1012,56 @@ def _binary_arithmetic_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @register_from_yaml_handler("Repeat") -def _repeat_from_yaml(yaml_dict: dict, compiler: Any) -> ibis.expr.types.Expr: - arg = translate_from_yaml(yaml_dict["args"][0], compiler) - times = translate_from_yaml(yaml_dict["args"][1], compiler) +def _repeat_from_yaml( + yaml_dict: dict, context: TranslationContext +) -> ibis.expr.types.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], context) + times = translate_from_yaml(yaml_dict["args"][1], context) return ops.Repeat(arg, times).to_expr() @register_from_yaml_handler("Sum") -def _sum_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] +def _sum_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + args = [translate_from_yaml(arg, context) for arg in yaml_dict["args"]] return 
args[0].sum() @register_from_yaml_handler("Min") -def _min_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] +def _min_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + args = [translate_from_yaml(arg, context) for arg in yaml_dict["args"]] return args[0].min() @register_from_yaml_handler("Max") -def _max_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] +def _max_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + args = [translate_from_yaml(arg, context) for arg in yaml_dict["args"]] return args[0].max() @register_from_yaml_handler("Abs") -def _abs_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - arg = translate_from_yaml(yaml_dict["args"][0], compiler) +def _abs_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], context) return arg.abs() @register_from_yaml_handler("Count") -def _count_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - arg = translate_from_yaml(yaml_dict["args"][0], compiler) +def _count_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], context) return arg.count() @register_from_yaml_handler("JoinReference") -def _join_reference_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: +def _join_reference_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: table_yaml = yaml_dict["args"][0] - return translate_from_yaml(table_yaml, compiler) + return translate_from_yaml(table_yaml, context) @register_from_yaml_handler( "Equals", "NotEquals", "GreaterThan", "GreaterEqual", "LessThan", "LessEqual" ) -def _binary_compare_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - left = translate_from_yaml(yaml_dict["args"][0], compiler) - right = translate_from_yaml(yaml_dict["args"][1], compiler) +def 
_binary_compare_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + left = translate_from_yaml(yaml_dict["args"][0], context) + right = translate_from_yaml(yaml_dict["args"][1], context) op_map = { "Equals": lambda left, right: left == right, @@ -1069,14 +1079,14 @@ def _binary_compare_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @register_from_yaml_handler("Between") -def _between_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] +def _between_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + args = [translate_from_yaml(arg, context) for arg in yaml_dict["args"]] return args[0].between(args[1], args[2]) @register_from_yaml_handler("Greater", "Less") -def _boolean_ops_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - args = [translate_from_yaml(arg, compiler) for arg in yaml_dict["args"]] +def _boolean_ops_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + args = [translate_from_yaml(arg, context) for arg in yaml_dict["args"]] op_name = yaml_dict["op"] op_map = { "Greater": lambda left, right: left > right, @@ -1086,30 +1096,30 @@ def _boolean_ops_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @register_from_yaml_handler("And") -def _boolean_and_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - args = [translate_from_yaml(arg, compiler) for arg in yaml_dict.get("args", [])] +def _boolean_and_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + args = [translate_from_yaml(arg, context) for arg in yaml_dict.get("args", [])] if not args: raise ValueError("And operator requires at least one argument") return functools.reduce(lambda x, y: x & y, args) @register_from_yaml_handler("Or") -def _boolean_or_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - args = [translate_from_yaml(arg, compiler) for arg in yaml_dict.get("args", [])] +def _boolean_or_from_yaml(yaml_dict: dict, context: 
TranslationContext) -> ir.Expr: + args = [translate_from_yaml(arg, context) for arg in yaml_dict.get("args", [])] if not args: raise ValueError("Or operator requires at least one argument") return functools.reduce(lambda x, y: x | y, args) @register_from_yaml_handler("Not") -def _not_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - arg = translate_from_yaml(yaml_dict["args"][0], compiler) +def _not_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], context) return ~arg @register_from_yaml_handler("IsNull") -def _is_null_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - arg = translate_from_yaml(yaml_dict["args"][0], compiler) +def _is_null_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], context) return arg.isnull() @@ -1121,8 +1131,8 @@ def _is_null_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: "ExtractMinute", "ExtractSecond", ) -def _extract_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - arg = translate_from_yaml(yaml_dict["args"][0], compiler) +def _extract_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], context) op_map = { "ExtractYear": lambda x: x.year(), "ExtractMonth": lambda x: x.month(), @@ -1135,16 +1145,18 @@ def _extract_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @register_from_yaml_handler("TimestampDiff") -def _timestamp_diff_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - left = translate_from_yaml(yaml_dict["args"][0], compiler) - right = translate_from_yaml(yaml_dict["args"][1], compiler) +def _timestamp_diff_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + left = translate_from_yaml(yaml_dict["args"][0], context) + right = translate_from_yaml(yaml_dict["args"][1], context) return left - right @register_from_yaml_handler("TimestampAdd", "TimestampSub") -def 
_timestamp_arithmetic_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - timestamp = translate_from_yaml(yaml_dict["args"][0], compiler) - interval = translate_from_yaml(yaml_dict["args"][1], compiler) +def _timestamp_arithmetic_from_yaml( + yaml_dict: dict, context: TranslationContext +) -> ir.Expr: + timestamp = translate_from_yaml(yaml_dict["args"][0], context) + interval = translate_from_yaml(yaml_dict["args"][1], context) if yaml_dict["op"] == "TimestampAdd": return timestamp + interval else: @@ -1152,29 +1164,29 @@ def _timestamp_arithmetic_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: @register_from_yaml_handler("Cast") -def _cast_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - arg = translate_from_yaml(yaml_dict["args"][0], compiler) +def _cast_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], context) target_dtype = _type_from_yaml(yaml_dict["type"]) return arg.cast(target_dtype) @register_from_yaml_handler("CountStar") -def _count_star_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: - arg = translate_from_yaml(yaml_dict["args"][0], compiler) +def _count_star_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: + arg = translate_from_yaml(yaml_dict["args"][0], context) return ops.CountStar(arg).to_expr() @register_from_yaml_handler("StringSQLLike") -def _string_sql_like_from_yaml(yaml_dict: dict, compiler: Any) -> ir.Expr: +def _string_sql_like_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: args = yaml_dict.get("args", []) if not args: raise ValueError("Missing arguments for StringSQLLike operator") - col = translate_from_yaml(args[0], compiler) + col = translate_from_yaml(args[0], context) if len(args) >= 2: - pattern_expr = translate_from_yaml(args[1], compiler) + pattern_expr = translate_from_yaml(args[1], context) else: pattern_value = args[0].get("value") if pattern_value is None: From 
5ba68572720cb6620a448717dfc594e41bea0ef8 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sun, 23 Feb 2025 11:06:56 -0500 Subject: [PATCH 36/45] fix(tests): expr_hash is not stable due to pins --- python/xorq/ibis_yaml/tests/test_compiler.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/xorq/ibis_yaml/tests/test_compiler.py b/python/xorq/ibis_yaml/tests/test_compiler.py index 410a92cf..8f38dd31 100644 --- a/python/xorq/ibis_yaml/tests/test_compiler.py +++ b/python/xorq/ibis_yaml/tests/test_compiler.py @@ -79,8 +79,7 @@ def test_ibis_compiler_parquet_reader(build_dir): ) expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") compiler = BuildManager(build_dir) - compiler.compile_expr(expr) - expr_hash = "9a7d0b20d41a" + expr_hash = compiler.compile_expr(expr) roundtrip_expr = compiler.load_expr(expr_hash) assert expr.execute().equals(roundtrip_expr.execute()) @@ -96,8 +95,7 @@ def test_compiler_sql(build_dir): expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") compiler = BuildManager(build_dir) - compiler.compile_expr(expr) - expr_hash = "79d83e9c89ad" + expr_hash = compiler.compile_expr(expr) _roundtrip_expr = compiler.load_expr(expr_hash) assert os.path.exists(build_dir / expr_hash / "sql.yaml") @@ -129,8 +127,7 @@ def test_deferred_reads_yaml(build_dir): expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") compiler = BuildManager(build_dir) - compiler.compile_expr(expr) - expr_hash = "79d83e9c89ad" + expr_hash = compiler.compile_expr(expr) _roundtrip_expr = compiler.load_expr(expr_hash) assert os.path.exists(build_dir / expr_hash / "deferred_reads.yaml") From 69722999cc650bc691bc759368c4dd69665e0fc8 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Sun, 23 Feb 2025 12:35:49 -0500 Subject: [PATCH 37/45] fix: make compiler test dynamic --- python/xorq/ibis_yaml/tests/test_compiler.py | 38 ++++++++++++++------ 1 file changed, 28 insertions(+), 
10 deletions(-) diff --git a/python/xorq/ibis_yaml/tests/test_compiler.py b/python/xorq/ibis_yaml/tests/test_compiler.py index 8f38dd31..4977a76d 100644 --- a/python/xorq/ibis_yaml/tests/test_compiler.py +++ b/python/xorq/ibis_yaml/tests/test_compiler.py @@ -1,12 +1,15 @@ import os import pathlib +import dask import pytest import yaml import xorq as xo +import xorq.vendor.ibis as ibis from xorq.common.utils.defer_utils import deferred_read_parquet from xorq.ibis_yaml.compiler import ArtifactStore, BuildManager +from xorq.ibis_yaml.sql import find_relations from xorq.vendor.ibis.common.collections import FrozenOrderedDict @@ -97,6 +100,8 @@ def test_compiler_sql(build_dir): compiler = BuildManager(build_dir) expr_hash = compiler.compile_expr(expr) _roundtrip_expr = compiler.load_expr(expr_hash) + expected_relation = find_relations(awards_players)[0] + expted_sql_hash = dask.base.tokenize(str(ibis.to_sql(expr)))[:12] assert os.path.exists(build_dir / expr_hash / "sql.yaml") assert os.path.exists(build_dir / expr_hash / "metadata.json") @@ -110,44 +115,57 @@ def test_compiler_sql(build_dir): " engine: let\n" f" profile_name: {expr._find_backend()._profile.hash_name}\n" " relations:\n" - " - awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\n" + f" - {expected_relation}\n" " options: {}\n" - " sql_file: df34d95d62bc.sql\n" + f" sql_file: {expted_sql_hash}.sql\n" ) assert sql_text == expected_result def test_deferred_reads_yaml(build_dir): backend = xo.datafusion.connect() + # Factor out the config path + config_path = xo.config.options.pins.get_path("awards_players") awards_players = deferred_read_parquet( backend, - xo.config.options.pins.get_path("awards_players"), + config_path, table_name="awards_players", ) expr = awards_players.filter(awards_players.lgID == "NL").drop("yearID", "lgID") + # Get the dynamic relation and profile hash + expected_relation = find_relations(awards_players)[0] + expected_profile = backend._profile.hash_name + compiler = 
BuildManager(build_dir) expr_hash = compiler.compile_expr(expr) _roundtrip_expr = compiler.load_expr(expr_hash) - assert os.path.exists(build_dir / expr_hash / "deferred_reads.yaml") - sql_text = pathlib.Path(build_dir / expr_hash / "deferred_reads.yaml").read_text() + yaml_path = build_dir / expr_hash / "deferred_reads.yaml" + assert os.path.exists(yaml_path) + sql_text = pathlib.Path(yaml_path).read_text() + + sql_str = str(ibis.to_sql(awards_players)) + expected_sql_file = dask.base.tokenize(sql_str)[:12] + ".sql" + + expected_read_path = str(config_path) expected_result = ( "reads:\n" - " awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f:\n" + f" {expected_relation}:\n" " engine: datafusion\n" - " profile_name: 30174be6bf62a829d7e62af391fc53b2\n" + f" profile_name: {expected_profile}\n" " relations:\n" - " - awards_players-eaf5fdf4554ae9098af6c7e7dfea1a9f\n" + f" - {expected_relation}\n" " options:\n" " method_name: read_parquet\n" " name: awards_players\n" " read_kwargs:\n" - " - path: /home/hussainsultan/.cache/pins-py/gs_d3037fb8920d01eb3b262ab08d52335c89ba62aa41299e5236f01807aa8b726d/awards_players/20240711T171119Z-886c4/awards_players.parquet\n" + f" - path: {expected_read_path}\n" " - table_name: awards_players\n" - " sql_file: c0907dab80b0.sql\n" + f" sql_file: {expected_sql_file}\n" ) + assert sql_text == expected_result From 913db61b39243c4d8e22cdd2fd3708b368373e9a Mon Sep 17 00:00:00 2001 From: dlovell Date: Mon, 24 Feb 2025 05:43:58 -0500 Subject: [PATCH 38/45] cln --- python/letsql/common/utils/graph_utils.py | 52 ----------------------- 1 file changed, 52 deletions(-) delete mode 100644 python/letsql/common/utils/graph_utils.py diff --git a/python/letsql/common/utils/graph_utils.py b/python/letsql/common/utils/graph_utils.py deleted file mode 100644 index 6908c24d..00000000 --- a/python/letsql/common/utils/graph_utils.py +++ /dev/null @@ -1,52 +0,0 @@ -import xorq.expr.relations as rel - - -def walk_nodes(node_types, expr): - def process_node(op): - 
match op: - case rel.RemoteTable(): - yield op - yield from walk_nodes( - node_types, - op.remote_expr, - ) - case rel.CachedNode(): - yield op - yield from walk_nodes( - node_types, - op.parent, - ) - case _: - yield from op.find(node_types) - - def inner(rest, seen): - if not rest: - return seen - op = rest.pop() - seen.add(op) - new = process_node(op) - rest.update(set(new).difference(seen)) - return inner(rest, seen) - - rest = process_node(expr.op()) - return inner(set(rest), set()) - - -def find_all_sources(expr): - import xorq.vendor.ibis.expr.operations as ops - - node_types = ( - ops.DatabaseTable, - ops.SQLQueryResult, - rel.CachedNode, - rel.Read, - rel.RemoteTable, - # ExprScalarUDF has an expr we need to get to - # FlightOperator has a dynamically generated connection: it should be passed a Profile instead - ) - nodes = walk_nodes(node_types, expr) - sources = tuple( - source - for (source, _) in set((node.source, node.source._profile) for node in nodes) - ) - return sources From 9f5c174033c8f7f40acf2d9bf2cdff97f7d2ecdd Mon Sep 17 00:00:00 2001 From: dlovell Date: Mon, 24 Feb 2025 05:49:13 -0500 Subject: [PATCH 39/45] cln(pyproject): remove hypothesis --- pyproject.toml | 1 - requirements-dev.txt | 3 +-- uv.lock | 17 ----------------- 3 files changed, 1 insertion(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bfb9f909..bef1c05e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,6 @@ dependencies = [ "sqlglot==25.20.2", "toolz>=0.11", "typing-extensions>=4.3.0", - "hypothesis>=6.124.9", "pyyaml>=6.0.2", "cloudpickle>=3.1.1", ] diff --git a/requirements-dev.txt b/requirements-dev.txt index 27552a4f..339c7efa 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -59,7 +59,6 @@ googleapis-common-protos==1.68.0 ; python_full_version < '4.0' greenlet==3.1.1 ; (python_full_version < '3.14' and platform_machine == 'AMD64') or (python_full_version < '3.14' and platform_machine == 'WIN32') or (python_full_version < 
'3.14' and platform_machine == 'aarch64') or (python_full_version < '3.14' and platform_machine == 'amd64') or (python_full_version < '3.14' and platform_machine == 'ppc64le') or (python_full_version < '3.14' and platform_machine == 'win32') or (python_full_version < '3.14' and platform_machine == 'x86_64') griffe==1.5.7 humanize==4.12.1 ; python_full_version < '4.0' -hypothesis==6.126.0 identify==2.6.7 idna==3.10 importlib-metadata==8.6.1 @@ -148,7 +147,7 @@ scipy==1.15.2 setuptools==75.8.0 ; sys_platform == 'darwin' six==1.17.0 snowflake-connector-python==3.13.2 ; python_full_version < '4.0' -sortedcontainers==2.4.0 +sortedcontainers==2.4.0 ; python_full_version < '4.0' sphobjinv==2.3.1.2 sqlalchemy==2.0.38 ; python_full_version < '4.0' sqlglot==25.20.2 diff --git a/uv.lock b/uv.lock index 12a71050..54e1bbb3 100644 --- a/uv.lock +++ b/uv.lock @@ -1144,20 +1144,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/30/5ef5994b090398f9284d2662f56853e5183ae2cb5d8e3db67e4f4cfea407/humanize-4.12.1-py3-none-any.whl", hash = "sha256:86014ca5c52675dffa1d404491952f1f5bf03b07c175a51891a343daebf01fea", size = 127409 }, ] -[[package]] -name = "hypothesis" -version = "6.126.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "attrs" }, - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "sortedcontainers" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a5/8c/8281dd1408dd8374b0ed0528e63fb53a556b3d4f901382f51148345ec9fb/hypothesis-6.126.0.tar.gz", hash = "sha256:648b6215ee0468fa85eaee9dceb5b7766a5861c20ee4801bd904a2c02f1a6c9b", size = 420895 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/98/fc/8e1749aa79631952bf70e57913626fb0a0556eb2ad3c530c5526f6e5ba13/hypothesis-6.126.0-py3-none-any.whl", hash = "sha256:323c58a773482a2b4ba4e35202560cfcba45e8a8e09e7ffb83c0f9bac5b544da", size = 483657 }, -] - [[package]] name = "identify" version = "2.6.7" @@ -3345,7 +3331,6 @@ 
dependencies = [ { name = "connectorx", marker = "python_full_version < '4.0'" }, { name = "dask", marker = "python_full_version < '4.0'" }, { name = "geoarrow-types", marker = "python_full_version < '4.0'" }, - { name = "hypothesis" }, { name = "pandas", marker = "python_full_version < '4.0'" }, { name = "parsy" }, { name = "psycopg2-binary", marker = "python_full_version < '4.0'" }, @@ -3434,7 +3419,6 @@ requires-dist = [ { name = "duckdb", marker = "extra == 'duckdb'", specifier = ">=1.1.3" }, { name = "fsspec", marker = "python_full_version >= '3.10' and python_full_version < '4.0' and extra == 'examples'", specifier = ">=2024.6.1,<2025.2.1" }, { name = "geoarrow-types", marker = "python_full_version >= '3.10' and python_full_version < '4.0'", specifier = ">=0.2,<1" }, - { name = "hypothesis", specifier = ">=6.124.9" }, { name = "pandas", marker = "python_full_version >= '3.10' and python_full_version < '4.0'", specifier = ">=1.5.3,<3" }, { name = "parsy", specifier = ">=2" }, { name = "pins", extras = ["gcs"], marker = "python_full_version >= '3.10' and python_full_version < '4.0' and extra == 'examples'", specifier = ">=0.8.3,<1" }, @@ -3457,7 +3441,6 @@ requires-dist = [ { name = "typing-extensions", specifier = ">=4.3.0" }, { name = "xgboost", marker = "python_full_version >= '3.10' and python_full_version < '4.0' and extra == 'examples'", specifier = ">=1.6.1" }, ] -provides-extras = ["duckdb", "datafusion", "snowflake", "quickgrove", "examples"] [package.metadata.requires-dev] dev = [ From b484b03ed60a212cd9680085034d041e7c2008aa Mon Sep 17 00:00:00 2001 From: dlovell Date: Mon, 24 Feb 2025 06:54:31 -0500 Subject: [PATCH 40/45] ref(ibis-yaml): use operator where possible --- python/xorq/ibis_yaml/translate.py | 41 +++++++++++++++--------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/python/xorq/ibis_yaml/translate.py b/python/xorq/ibis_yaml/translate.py index 289eec67..5107e148 100644 --- a/python/xorq/ibis_yaml/translate.py +++ 
b/python/xorq/ibis_yaml/translate.py @@ -3,6 +3,7 @@ import datetime import decimal import functools +import operator from typing import Any import xorq.vendor.ibis as ibis @@ -1000,10 +1001,10 @@ def _binary_arithmetic_from_yaml( left = translate_from_yaml(yaml_dict["args"][0], context) right = translate_from_yaml(yaml_dict["args"][1], context) op_map = { - "Add": lambda left, right: left + right, - "Subtract": lambda left, right: left - right, - "Multiply": lambda left, right: left * right, - "Divide": lambda left, right: left / right, + "Add": operator.add, + "Subtract": operator.sub, + "Multiply": operator.mul, + "Divide": operator.truediv, } op_func = op_map.get(yaml_dict["op"]) if op_func is None: @@ -1064,12 +1065,12 @@ def _binary_compare_from_yaml(yaml_dict: dict, context: TranslationContext) -> i right = translate_from_yaml(yaml_dict["args"][1], context) op_map = { - "Equals": lambda left, right: left == right, - "NotEquals": lambda left, right: left != right, - "GreaterThan": lambda left, right: left > right, - "GreaterEqual": lambda left, right: left >= right, - "LessThan": lambda left, right: left < right, - "LessEqual": lambda left, right: left <= right, + "Equals": operator.eq, + "NotEquals": operator.ne, + "GreaterThan": operator.gt, + "GreaterEqual": operator.ge, + "LessThan": operator.lt, + "LessEqual": operator.le, } op_func = op_map.get(yaml_dict["op"]) @@ -1089,8 +1090,8 @@ def _boolean_ops_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.E args = [translate_from_yaml(arg, context) for arg in yaml_dict["args"]] op_name = yaml_dict["op"] op_map = { - "Greater": lambda left, right: left > right, - "Less": lambda left, right: left < right, + "Greater": operator.gt, + "Less": operator.lt, } return op_map[op_name](*args) @@ -1100,7 +1101,7 @@ def _boolean_and_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.E args = [translate_from_yaml(arg, context) for arg in yaml_dict.get("args", [])] if not args: raise ValueError("And 
operator requires at least one argument") - return functools.reduce(lambda x, y: x & y, args) + return functools.reduce(operator.and_, args) @register_from_yaml_handler("Or") @@ -1108,7 +1109,7 @@ def _boolean_or_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Ex args = [translate_from_yaml(arg, context) for arg in yaml_dict.get("args", [])] if not args: raise ValueError("Or operator requires at least one argument") - return functools.reduce(lambda x, y: x | y, args) + return functools.reduce(operator.or_, args) @register_from_yaml_handler("Not") @@ -1134,12 +1135,12 @@ def _is_null_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: def _extract_from_yaml(yaml_dict: dict, context: TranslationContext) -> ir.Expr: arg = translate_from_yaml(yaml_dict["args"][0], context) op_map = { - "ExtractYear": lambda x: x.year(), - "ExtractMonth": lambda x: x.month(), - "ExtractDay": lambda x: x.day(), - "ExtractHour": lambda x: x.hour(), - "ExtractMinute": lambda x: x.minute(), - "ExtractSecond": lambda x: x.second(), + "ExtractYear": operator.methodcaller("year"), + "ExtractMonth": operator.methodcaller("month"), + "ExtractDay": operator.methodcaller("day"), + "ExtractHour": operator.methodcaller("hour"), + "ExtractMinute": operator.methodcaller("minute"), + "ExtractSecond": operator.methodcaller("second"), } return op_map[yaml_dict["op"]](arg) From d403f0ac0fe2f05fe1678f4dc121ca8472cbc05c Mon Sep 17 00:00:00 2001 From: dlovell Date: Mon, 24 Feb 2025 07:32:53 -0500 Subject: [PATCH 41/45] ref(ibis-yaml): use config.hash_length --- python/xorq/ibis_yaml/tests/test_compiler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/xorq/ibis_yaml/tests/test_compiler.py b/python/xorq/ibis_yaml/tests/test_compiler.py index 4977a76d..c0a596b4 100644 --- a/python/xorq/ibis_yaml/tests/test_compiler.py +++ b/python/xorq/ibis_yaml/tests/test_compiler.py @@ -9,6 +9,7 @@ import xorq.vendor.ibis as ibis from xorq.common.utils.defer_utils 
import deferred_read_parquet from xorq.ibis_yaml.compiler import ArtifactStore, BuildManager +from xorq.ibis_yaml.config import config from xorq.ibis_yaml.sql import find_relations from xorq.vendor.ibis.common.collections import FrozenOrderedDict @@ -101,7 +102,7 @@ def test_compiler_sql(build_dir): expr_hash = compiler.compile_expr(expr) _roundtrip_expr = compiler.load_expr(expr_hash) expected_relation = find_relations(awards_players)[0] - expted_sql_hash = dask.base.tokenize(str(ibis.to_sql(expr)))[:12] + expted_sql_hash = dask.base.tokenize(str(ibis.to_sql(expr)))[: config.hash_length] assert os.path.exists(build_dir / expr_hash / "sql.yaml") assert os.path.exists(build_dir / expr_hash / "metadata.json") @@ -146,7 +147,7 @@ def test_deferred_reads_yaml(build_dir): sql_text = pathlib.Path(yaml_path).read_text() sql_str = str(ibis.to_sql(awards_players)) - expected_sql_file = dask.base.tokenize(sql_str)[:12] + ".sql" + expected_sql_file = dask.base.tokenize(sql_str)[: config.hash_length] + ".sql" expected_read_path = str(config_path) From 466f65b4188b1cbeab9ed7dab52bd8b1faf33315 Mon Sep 17 00:00:00 2001 From: dlovell Date: Mon, 24 Feb 2025 07:35:09 -0500 Subject: [PATCH 42/45] cln(ibis-yaml): remove check, path.open will raise --- python/xorq/ibis_yaml/compiler.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/xorq/ibis_yaml/compiler.py b/python/xorq/ibis_yaml/compiler.py index 904de050..f97837cc 100644 --- a/python/xorq/ibis_yaml/compiler.py +++ b/python/xorq/ibis_yaml/compiler.py @@ -65,15 +65,11 @@ def write_yaml(self, data: Dict[str, Any], *path_parts) -> pathlib.Path: def read_yaml(self, *path_parts) -> Dict[str, Any]: path = self.get_path(*path_parts) - if not path.exists(): - raise FileNotFoundError(f"File not found: {path}") with path.open("r") as f: return yaml.safe_load(f) def read_json(self, *path_parts) -> Dict[str, Any]: path = self.get_path(*path_parts) - if not path.exists(): - raise FileNotFoundError(f"File not found: {path}") with 
path.open("r") as f: return json.load(f) @@ -86,8 +82,6 @@ def write_text(self, content: str, *path_parts) -> pathlib.Path: def read_text(self, *path_parts) -> str: path = self.get_path(*path_parts) - if not path.exists(): - raise FileNotFoundError(f"File not found: {path}") with path.open("r") as f: return f.read() From 4c16d5023d76fc673a8d8bb7ffb48e368266d084 Mon Sep 17 00:00:00 2001 From: dlovell Date: Mon, 24 Feb 2025 07:59:50 -0500 Subject: [PATCH 43/45] ref(ibis-yaml): defined ArtifactStore.{_read,_write} --- python/xorq/ibis_yaml/compiler.py | 42 +++++++++++++++++-------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/python/xorq/ibis_yaml/compiler.py b/python/xorq/ibis_yaml/compiler.py index f97837cc..1de8fed7 100644 --- a/python/xorq/ibis_yaml/compiler.py +++ b/python/xorq/ibis_yaml/compiler.py @@ -1,4 +1,6 @@ +import contextlib import json +import operator import pathlib from pathlib import Path from typing import Any, Dict @@ -50,10 +52,29 @@ def ensure_dir(self, *parts) -> pathlib.Path: path.mkdir(parents=True, exist_ok=True) return path - def write_yaml(self, data: Dict[str, Any], *path_parts) -> pathlib.Path: + def _read(self, read_f, *parts): + path = self.get_path(*parts) + with path.open("r") as f: + return read_f(f) + + def read_yaml(self, *path_parts) -> Dict[str, Any]: + return self._read(yaml.safe_load, *path_parts) + + def read_json(self, *path_parts) -> Dict[str, Any]: + return self._read(json.load, *path_parts) + + def read_text(self, *path_parts) -> str: + return self._read(operator.methodcaller("read"), *path_parts) + + @contextlib.contextmanager + def _write(self, *path_parts): path = self.get_path(*path_parts) path.parent.mkdir(parents=True, exist_ok=True) with path.open("w") as f: + yield (path, f) + + def write_yaml(self, data: Dict[str, Any], *path_parts) -> pathlib.Path: + with self._write(*path_parts) as (path, f): yaml.dump( data, f, @@ -63,28 +84,11 @@ def write_yaml(self, data: Dict[str, Any], *path_parts) 
-> pathlib.Path: ) return path - def read_yaml(self, *path_parts) -> Dict[str, Any]: - path = self.get_path(*path_parts) - with path.open("r") as f: - return yaml.safe_load(f) - - def read_json(self, *path_parts) -> Dict[str, Any]: - path = self.get_path(*path_parts) - with path.open("r") as f: - return json.load(f) - def write_text(self, content: str, *path_parts) -> pathlib.Path: - path = self.get_path(*path_parts) - path.parent.mkdir(parents=True, exist_ok=True) - with path.open("w") as f: + with self._write(*path_parts) as (path, f): f.write(content) return path - def read_text(self, *path_parts) -> str: - path = self.get_path(*path_parts) - with path.open("r") as f: - return f.read() - def exists(self, *path_parts) -> bool: return self.get_path(*path_parts).exists() From 58f6310d2ee7760d99f1447142298d18e68468c5 Mon Sep 17 00:00:00 2001 From: dlovell Date: Mon, 24 Feb 2025 09:58:32 -0500 Subject: [PATCH 44/45] ref: use itertools.count --- python/xorq/ibis_yaml/common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/xorq/ibis_yaml/common.py b/python/xorq/ibis_yaml/common.py index 374426b6..f328b44d 100644 --- a/python/xorq/ibis_yaml/common.py +++ b/python/xorq/ibis_yaml/common.py @@ -1,4 +1,5 @@ import functools +import itertools from typing import Any import attr @@ -15,7 +16,7 @@ class SchemaRegistry: def __init__(self): self.schemas = {} - self.counter = 0 + self.counter = itertools.count() def register_schema(self, schema): frozen_schema = freeze( @@ -26,9 +27,8 @@ def register_schema(self, schema): if existing_schema == frozen_schema: return schema_id - schema_id = f"schema_{self.counter}" + schema_id = f"schema_{next(self.counter)}" self.schemas[schema_id] = frozen_schema - self.counter += 1 return schema_id def _register_expr_schema(self, expr: ir.Expr) -> str: From 7a401c2ea023bd81b992708ae97a326c35c943f5 Mon Sep 17 00:00:00 2001 From: dlovell Date: Mon, 24 Feb 2025 10:26:36 -0500 Subject: [PATCH 45/45] ref --- 
python/xorq/ibis_yaml/sql.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/python/xorq/ibis_yaml/sql.py b/python/xorq/ibis_yaml/sql.py index 73c4efdc..90ad7e60 100644 --- a/python/xorq/ibis_yaml/sql.py +++ b/python/xorq/ibis_yaml/sql.py @@ -22,11 +22,7 @@ class DeferredReadsPlan(TypedDict): def find_relations(expr: ir.Expr) -> List[str]: - node_types = (RemoteTable, Read, ops.DatabaseTable) - nodes = walk_nodes(node_types, expr) - relations = [] - seen = set() - for node in nodes: + def get_name(node): name = None if isinstance(node, RemoteTable): name = node.name @@ -34,9 +30,11 @@ def find_relations(expr: ir.Expr) -> List[str]: name = node.make_unbound_dt().name elif isinstance(node, ops.DatabaseTable): name = node.name - if name and name not in seen: - seen.add(name) - relations.append(name) + return name + + node_types = (RemoteTable, Read, ops.DatabaseTable) + nodes = walk_nodes(node_types, expr) + relations = list(set(filter(None, map(get_name, nodes)))) return relations