From 42bbedbc71ea6f89c39373f5469f0916cc47aa36 Mon Sep 17 00:00:00 2001 From: jackwener Date: Thu, 5 Jan 2023 12:57:33 +0800 Subject: [PATCH 1/2] bugfix: remove cnf_rewrite in push_down_filter --- benchmarks/expected-plans/q19.txt | 13 ++-- benchmarks/expected-plans/q7.txt | 2 +- datafusion/core/tests/sql/joins.rs | 2 +- datafusion/core/tests/sql/predicates.rs | 6 +- datafusion/optimizer/src/optimizer.rs | 2 +- datafusion/optimizer/src/push_down_filter.rs | 66 ++++++++++++++------ 6 files changed, 58 insertions(+), 33 deletions(-) diff --git a/benchmarks/expected-plans/q19.txt b/benchmarks/expected-plans/q19.txt index 3efc3718d01a..969ad02d4c59 100644 --- a/benchmarks/expected-plans/q19.txt +++ b/benchmarks/expected-plans/q19.txt @@ -1,9 +1,8 @@ Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue Aggregate: groupBy=[[]], aggr=[[SUM(CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] - Projection: lineitem.l_extendedprice, lineitem.l_discount - Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) - Inner Join: lineitem.l_partkey = part.p_partkey - Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON") - TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode] - Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) - TableScan: part projection=[p_partkey, p_brand, p_size, p_container] + Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) + Inner Join: lineitem.l_partkey = part.p_partkey + Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON") + TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode] + Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) + TableScan: part projection=[p_partkey, p_brand, p_size, p_container] diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt index 53deda1b87c0..bd8c10f8cf88 100644 --- a/benchmarks/expected-plans/q7.txt +++ b/benchmarks/expected-plans/q7.txt @@ -3,7 +3,7 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST, Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]] SubqueryAlias: shipping Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume - Filter: (n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE")) AND (n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY")) + Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") Inner Join: customer.c_nationkey = n2.n_nationkey Inner Join: supplier.s_nationkey = n1.n_nationkey Inner Join: orders.o_custkey = customer.c_custkey diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 4cc9628d15ae..1de20c29cd07 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1561,7 +1561,7 @@ async fn reduce_left_join_2() -> Result<()> { let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs index 61d509a2ddc8..1e8888ce45f9 100644 --- a/datafusion/core/tests/sql/predicates.rs +++ b/datafusion/core/tests/sql/predicates.rs @@ -590,10 +590,8 @@ async fn multiple_or_predicates() -> Result<()> { " Projection: lineitem.l_partkey [l_partkey:Int64]", " Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]", " Inner Join: lineitem.l_partkey = part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]", - " Projection: lineitem.l_partkey, lineitem.l_quantity [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", - " Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) [lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean, l_partkey:Int64, l_quantity:Decimal128(15, 2)]", - " Projection: lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(3000),15,2) AS lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(2000),15,2) AS lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity, lineitem.l_partkey, lineitem.l_quantity [lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean, l_partkey:Int64, l_quantity:Decimal128(15, 2)]", - " TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", + " Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", + " TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", " Filter: (part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", " TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", ]; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 3e7df55d5b1b..7390ac2041e7 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -215,12 +215,12 @@ impl Optimizer { // run it again after running the optimizations that potentially converted // subqueries to joins Arc::new(SimplifyExpressions::new()), + Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(EliminateFilter::new()), Arc::new(EliminateCrossJoin::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), Arc::new(PropagateEmptyRelation::new()), - Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(FilterNullJoinKeys::default()), Arc::new(EliminateOuterJoin::new()), // Filters can't be pushed down past Limits, we should do PushDownFilter after LimitPushDown diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index ff0ea5b23bef..35c5dacfa28f 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -14,7 +14,7 @@ //! Push Down Filter optimizer rule ensures that filters are applied as early as possible in the plan -use crate::utils::conjunction; +use crate::utils::{conjunction, split_conjunction}; use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, DataFusionError, Result}; use datafusion_expr::{ @@ -421,7 +421,7 @@ fn push_down_join( ) -> Result> { let mut predicates = match parent_predicate { Some(parent_predicate) => { - utils::split_conjunction_owned(utils::cnf_rewrite(parent_predicate.clone())) + utils::split_conjunction_owned(parent_predicate.clone()) } None => vec![], }; @@ -538,8 +538,21 @@ impl OptimizerRule for PushDownFilter { let child_plan = filter.input.as_ref(); let new_plan = match child_plan { LogicalPlan::Filter(child_filter) => { - let new_predicate = - and(filter.predicate.clone(), child_filter.predicate.clone()); + let parents_predicates = split_conjunction(&filter.predicate); + let set: HashSet<&&Expr> = parents_predicates.iter().collect(); + + let new_predicates = parents_predicates + .iter() + .chain( + split_conjunction(&child_filter.predicate) + .iter() + .filter(|e| !set.contains(e)), + ) + .map(|e| (*e).clone()) + .collect::>(); + let new_predicate = conjunction(new_predicates).ok_or( + DataFusionError::Plan("at least one expression exists".to_string()), + )?; let new_plan = LogicalPlan::Filter(Filter::try_new( new_predicate, child_filter.input.clone(), @@ -638,9 +651,7 @@ impl OptimizerRule for PushDownFilter { .map(|e| Ok(Column::from_qualified_name(e.display_name()?))) .collect::>>()?; - let predicates = utils::split_conjunction_owned(utils::cnf_rewrite( - filter.predicate.clone(), - )); + let predicates = utils::split_conjunction_owned(filter.predicate.clone()); let mut keep_predicates = vec![]; let mut push_predicates = vec![]; @@ -689,19 +700,15 @@ impl OptimizerRule for PushDownFilter { } } LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - let predicates = utils::split_conjunction_owned(utils::cnf_rewrite( - filter.predicate.clone(), - )); - + let predicates = utils::split_conjunction_owned(filter.predicate.clone()); push_down_all_join(predicates, &filter.input, left, right, vec![])? } LogicalPlan::TableScan(scan) => { let mut new_scan_filters = scan.filters.clone(); let mut new_predicate = vec![]; - let filter_predicates = utils::split_conjunction_owned( - utils::cnf_rewrite(filter.predicate.clone()), - ); + let filter_predicates = + utils::split_conjunction_owned(filter.predicate.clone()); for filter_expr in &filter_predicates { let (preserve_filter_node, add_to_provider) = @@ -754,7 +761,10 @@ impl PushDownFilter { } /// replaces columns by its name on the projection. -fn replace_cols_by_name(e: Expr, replace_map: &HashMap) -> Result { +pub fn replace_cols_by_name( + e: Expr, + replace_map: &HashMap, +) -> Result { struct ColumnReplacer<'a> { replace_map: &'a HashMap, } @@ -778,6 +788,7 @@ fn replace_cols_by_name(e: Expr, replace_map: &HashMap) -> Result< #[cfg(test)] mod tests { use super::*; + use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::test::*; use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -801,6 +812,24 @@ mod tests { Ok(()) } + fn assert_optimized_plan_eq_with_rewrite_predicate( + plan: &LogicalPlan, + expected: &str, + ) -> Result<()> { + let mut optimized_plan = RewriteDisjunctivePredicate::new() + .try_optimize(plan, &OptimizerContext::new()) + .unwrap() + .expect("failed to optimize plan"); + optimized_plan = PushDownFilter::new() + .try_optimize(&optimized_plan, &OptimizerContext::new()) + .unwrap() + .expect("failed to optimize plan"); + let formatted_plan = format!("{optimized_plan:?}"); + assert_eq!(plan.schema(), optimized_plan.schema()); + assert_eq!(expected, formatted_plan); + Ok(()) + } + #[test] fn filter_before_projection() -> Result<()> { let table_scan = test_table_scan()?; @@ -2281,20 +2310,19 @@ mod tests { .build()?; let expected = "\ - Filter: (test.a = d OR test.b = e) AND (test.a = d OR test.c < UInt32(10)) AND (test.b > UInt32(1) OR test.b = e)\ + Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ \n CrossJoin:\ \n Projection: test.a, test.b, test.c\ \n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\ \n TableScan: test\ \n Projection: test1.a AS d, test1.a AS e\ \n TableScan: test1"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?; // Originally global state which can help to avoid duplicate Filters been generated and pushed down. // Now the global state is removed. Need to double confirm that avoid duplicate Filters. let optimized_plan = PushDownFilter::new() - .try_optimize(&plan, &OptimizerContext::new()) - .unwrap() + .try_optimize(&plan, &OptimizerContext::new())? .expect("failed to optimize plan"); assert_optimized_plan_eq(&optimized_plan, expected) } From 83bcaea63b2d8e68f0895c052866920d61848d19 Mon Sep 17 00:00:00 2001 From: jackwener Date: Fri, 6 Jan 2023 12:31:28 +0800 Subject: [PATCH 2/2] remove cnf_rewrite() and related tests. --- datafusion/optimizer/src/utils.rs | 233 +----------------------------- 1 file changed, 2 insertions(+), 231 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index d9c7477cbb8a..ed0f07f5503e 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -29,7 +29,7 @@ use datafusion_expr::{ utils::from_plan, Expr, Operator, }; -use std::collections::{HashSet, VecDeque}; +use std::collections::HashSet; use std::sync::Arc; /// Convenience rule for writing optimizers: recursively invoke @@ -170,104 +170,6 @@ fn split_binary_impl<'a>( } } -/// Given a list of lists of [`Expr`]s, returns a list of lists of -/// [`Expr`]s of expressions where there is one expression from each -/// from each of the input expressions -/// -/// For example, given the input `[[a, b], [c], [d, e]]` returns -/// `[a, c, d], [a, c, e], [b, c, d], [b, c, e]]`. -fn permutations(mut exprs: VecDeque>) -> Vec> { - let first = if let Some(first) = exprs.pop_front() { - first - } else { - return vec![]; - }; - - // base case: - if exprs.is_empty() { - first.into_iter().map(|e| vec![e]).collect() - } else { - first - .into_iter() - .flat_map(|expr| { - permutations(exprs.clone()) - .into_iter() - .map(|expr_list| { - // Create [expr, ...] for each permutation - std::iter::once(expr) - .chain(expr_list.into_iter()) - .collect::>() - }) - .collect::>>() - }) - .collect() - } -} - -const MAX_CNF_REWRITE_CONJUNCTS: usize = 10; - -/// Tries to convert an expression to conjunctive normal form (CNF). -/// -/// Does not convert the expression if the total number of conjuncts -/// (exprs ANDed together) would exceed [`MAX_CNF_REWRITE_CONJUNCTS`]. -/// -/// The following expression is in CNF: -/// `(a OR b) AND (c OR d)` -/// -/// The following is not in CNF: -/// `(a AND b) OR c`. -/// -/// But could be rewrite to a CNF expression: -/// `(a OR c) AND (b OR c)`. -/// -/// -/// # Example -/// ``` -/// # use datafusion_expr::{col, lit}; -/// # use datafusion_optimizer::utils::cnf_rewrite; -/// // (a=1 AND b=2)OR c = 3 -/// let expr1 = col("a").eq(lit(1)).and(col("b").eq(lit(2))); -/// let expr2 = col("c").eq(lit(3)); -/// let expr = expr1.or(expr2); -/// -/// //(a=1 or c=3)AND(b=2 or c=3) -/// let expr1 = col("a").eq(lit(1)).or(col("c").eq(lit(3))); -/// let expr2 = col("b").eq(lit(2)).or(col("c").eq(lit(3))); -/// let expect = expr1.and(expr2); -/// assert_eq!(expect, cnf_rewrite(expr)); -/// ``` -pub fn cnf_rewrite(expr: Expr) -> Expr { - // Find all exprs joined by OR - let disjuncts = split_binary(&expr, Operator::Or); - - // For each expr, split now on AND - // A OR B OR C --> split each A, B and C - let disjunct_conjuncts: VecDeque> = disjuncts - .into_iter() - .map(|e| split_binary(e, Operator::And)) - .collect::>(); - - // Decide if we want to distribute the clauses. Heuristic is - // chosen to avoid creating huge predicates - let num_conjuncts = disjunct_conjuncts - .iter() - .fold(1usize, |sz, exprs| sz.saturating_mul(exprs.len())); - - if disjunct_conjuncts.iter().any(|exprs| exprs.len() > 1) - && num_conjuncts < MAX_CNF_REWRITE_CONJUNCTS - { - let or_clauses = permutations(disjunct_conjuncts) - .into_iter() - // form the OR clauses( A OR B OR C ..) - .map(|exprs| disjunction(exprs.into_iter().cloned()).unwrap()); - conjunction(or_clauses).unwrap() - } - // otherwise return the original expression - else { - expr - } -} - /// Combines an array of filter expressions into a single filter /// expression consisting of the input filter expressions joined with /// logical AND. @@ -614,7 +516,7 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::Column; use datafusion_expr::expr::Cast; - use datafusion_expr::{col, lit, or, utils::expr_to_columns}; + use datafusion_expr::{col, lit, utils::expr_to_columns}; use std::collections::HashSet; use std::ops::Add; @@ -815,135 +717,4 @@ mod tests { "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" ) } - - #[test] - fn test_permutations() { - assert_eq!(make_permutations(vec![]), vec![] as Vec>) - } - - #[test] - fn test_permutations_one() { - // [[a]] --> [[a]] - assert_eq!( - make_permutations(vec![vec![col("a")]]), - vec![vec![col("a")]] - ) - } - - #[test] - fn test_permutations_two() { - // [[a, b]] --> [[a], [b]] - assert_eq!( - make_permutations(vec![vec![col("a"), col("b")]]), - vec![vec![col("a")], vec![col("b")]] - ) - } - - #[test] - fn test_permutations_two_and_one() { - // [[a, b], [c]] --> [[a, c], [b, c]] - assert_eq!( - make_permutations(vec![vec![col("a"), col("b")], vec![col("c")]]), - vec![vec![col("a"), col("c")], vec![col("b"), col("c")]] - ) - } - - #[test] - fn test_permutations_two_and_one_and_two() { - // [[a, b], [c], [d, e]] --> [[a, c, d], [a, c, e], [b, c, d], [b, c, e]] - assert_eq!( - make_permutations(vec![ - vec![col("a"), col("b")], - vec![col("c")], - vec![col("d"), col("e")] - ]), - vec![ - vec![col("a"), col("c"), col("d")], - vec![col("a"), col("c"), col("e")], - vec![col("b"), col("c"), col("d")], - vec![col("b"), col("c"), col("e")], - ] - ) - } - - /// call permutations with owned `Expr`s for easier testing - fn make_permutations(exprs: impl IntoIterator>) -> Vec> { - let exprs = exprs.into_iter().collect::>(); - - let exprs: VecDeque> = exprs - .iter() - .map(|exprs| exprs.iter().collect::>()) - .collect(); - - permutations(exprs) - .into_iter() - // copy &Expr --> Expr - .map(|exprs| exprs.into_iter().cloned().collect()) - .collect() - } - - #[test] - fn test_rewrite_cnf() { - let a_1 = col("a").eq(lit(1i64)); - let a_2 = col("a").eq(lit(2i64)); - - let b_1 = col("b").eq(lit(1i64)); - let b_2 = col("b").eq(lit(2i64)); - - // Test rewrite on a1_and_b2 and a2_and_b1 -> not change - let expr1 = and(a_1.clone(), b_2.clone()); - let expect = expr1.clone(); - assert_eq!(expect, cnf_rewrite(expr1)); - - // Test rewrite on a1_and_b2 and a2_and_b1 -> (((a1 and b2) and a2) and b1) - let expr1 = and(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); - let expect = and(a_1.clone(), b_2.clone()) - .and(a_2.clone()) - .and(b_1.clone()); - assert_eq!(expect, cnf_rewrite(expr1)); - - // Test rewrite on a1_or_b2 -> not change - let expr1 = or(a_1.clone(), b_2.clone()); - let expect = expr1.clone(); - assert_eq!(expect, cnf_rewrite(expr1)); - - // Test rewrite on a1_and_b2 or a2_and_b1 -> a1_or_a2 and a1_or_b1 and b2_or_a2 and b2_or_b1 - let expr1 = or(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); - let a1_or_a2 = or(a_1.clone(), a_2.clone()); - let a1_or_b1 = or(a_1.clone(), b_1.clone()); - let b2_or_a2 = or(b_2.clone(), a_2.clone()); - let b2_or_b1 = or(b_2.clone(), b_1.clone()); - let expect = and(a1_or_a2, a1_or_b1).and(b2_or_a2).and(b2_or_b1); - assert_eq!(expect, cnf_rewrite(expr1)); - - // Test rewrite on a1_or_b2 or a2_and_b1 -> ( a1_or_a2 or a2 ) and (a1_or_a2 or b1) - let a1_or_b2 = or(a_1.clone(), b_2.clone()); - let expr1 = or(or(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); - let expect = or(a1_or_b2.clone(), a_2.clone()).and(or(a1_or_b2, b_1.clone())); - assert_eq!(expect, cnf_rewrite(expr1)); - - // Test rewrite on a1_or_b2 or a2_or_b1 -> not change - let expr1 = or(or(a_1, b_2), or(a_2, b_1)); - let expect = expr1.clone(); - assert_eq!(expect, cnf_rewrite(expr1)); - } - - #[test] - fn test_rewrite_cnf_overflow() { - // in this situation: - // AND = (a=1 and b=2) - // rewrite (AND * 10) or (AND * 10), it will produce 10 * 10 = 100 (a=1 or b=2) - // which cause size expansion. - - let mut expr1 = col("test1").eq(lit(1i64)); - let expr2 = col("test2").eq(lit(2i64)); - - for _i in 0..9 { - expr1 = expr1.clone().and(expr2.clone()); - } - let expr3 = expr1.clone(); - let expr = or(expr1, expr3); - - assert_eq!(expr, cnf_rewrite(expr.clone())); - } }