diff --git a/benchmarks/expected-plans/q11.txt b/benchmarks/expected-plans/q11.txt index e0ccdfce87c6..bb5493828e99 100644 --- a/benchmarks/expected-plans/q11.txt +++ b/benchmarks/expected-plans/q11.txt @@ -14,7 +14,7 @@ Sort: value DESC NULLS FIRST Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] Inner Join: supplier.s_nationkey = nation.n_nationkey Inner Join: partsupp.ps_suppkey = supplier.s_suppkey - TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] + TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] TableScan: supplier projection=[s_suppkey, s_nationkey] Filter: nation.n_name = Utf8("GERMANY") TableScan: nation projection=[n_nationkey, n_name] \ No newline at end of file diff --git a/benchmarks/expected-plans/q15.txt b/benchmarks/expected-plans/q15.txt index 1100d17b617d..96401dd7bd81 100644 --- a/benchmarks/expected-plans/q15.txt +++ b/benchmarks/expected-plans/q15.txt @@ -7,10 +7,9 @@ Sort: supplier.s_suppkey ASC NULLS LAST SubqueryAlias: revenue0 Projection: supplier_no, total_revenue Projection: lineitem.l_suppkey AS supplier_no, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue - Projection: lineitem.l_suppkey, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) - Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[SUM(CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] - Filter: lineitem.l_shipdate >= Date32("9496") AND lineitem.l_shipdate < Date32("9587") - TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate] + Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[SUM(CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - 
CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] + Filter: lineitem.l_shipdate >= Date32("9496") AND lineitem.l_shipdate < Date32("9587") + TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate] SubqueryAlias: __sq_1 Projection: MAX(revenue0.total_revenue) AS __value Aggregate: groupBy=[[]], aggr=[[MAX(revenue0.total_revenue)]] diff --git a/benchmarks/expected-plans/q2.txt b/benchmarks/expected-plans/q2.txt index 845d79263d89..34fb1e09a2f0 100644 --- a/benchmarks/expected-plans/q2.txt +++ b/benchmarks/expected-plans/q2.txt @@ -20,7 +20,7 @@ Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, supplie Inner Join: supplier.s_nationkey = nation.n_nationkey Inner Join: partsupp.ps_suppkey = supplier.s_suppkey TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] - TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] - TableScan: nation projection=[n_nationkey, n_name, n_regionkey] + TableScan: supplier projection=[s_suppkey, s_nationkey] + TableScan: nation projection=[n_nationkey, n_regionkey] Filter: region.r_name = Utf8("EUROPE") TableScan: region projection=[r_regionkey, r_name] \ No newline at end of file diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index bcbbead561b9..93ec057abc8a 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -64,7 +64,7 @@ async fn explain_analyze_baseline_metrics() { ); assert_metrics!( &formatted, - "SortExec: [c1@1 ASC NULLS LAST]", + "SortExec: [c1@0 ASC NULLS LAST]", "metrics=[output_rows=5, elapsed_compute=" ); assert_metrics!( @@ -573,7 +573,7 @@ async fn csv_explain_verbose_plans() { // Since the plan contains path that are environmentally // dependant(e.g. 
full path of the test file), only verify // important content - assert_contains!(&actual, "logical_plan after projection_push_down"); + assert_contains!(&actual, "logical_plan after push_down_projection"); assert_contains!(&actual, "physical_plan"); assert_contains!(&actual, "FilterExec: c2@1 > 10"); assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]"); @@ -744,7 +744,7 @@ async fn test_physical_plan_display_indent_multi_children() { " RepartitionExec: partitioning=Hash([Column { name: \"c2\", index: 0 }], 9000)", " ProjectionExec: expr=[c1@0 as c2]", " RepartitionExec: partitioning=RoundRobinBatch(9000)", - " CsvExec: files={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, has_header=true, limit=None, projection=[c1, c2]", + " CsvExec: files={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, has_header=true, limit=None, projection=[c1]", ]; let normalizer = ExplainNormalizer::new(); diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index c3659451a991..263244880fd2 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1635,13 +1635,14 @@ async fn reduce_left_join_3() -> Result<()> { let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t3.t1_id, t3.t1_name, t3.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Left Join: t3.t1_int = t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " SubqueryAlias: t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 
projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: t2.t2_int < UInt32(3) AND t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Left Join: t3.t1_int = t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " SubqueryAlias: t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_int:UInt32;N]", + " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Filter: t2.t2_int < UInt32(3) AND t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; let formatted = plan.display_indent_schema().to_string(); diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs index 21d46f7c8f71..0510b12637b2 100644 --- a/datafusion/core/tests/sql/predicates.rs +++ b/datafusion/core/tests/sql/predicates.rs @@ -549,8 +549,7 @@ async fn in_set_string_dictionaries() -> Result<()> { } #[tokio::test] -#[ignore] -// https://github.com/apache/arrow-datafusion/issues/3635 +// Test issue: https://github.com/apache/arrow-datafusion/issues/3635 async fn multiple_or_predicates() -> Result<()> { let ctx = SessionContext::new(); register_tpch_csv(&ctx, "lineitem").await?; @@ -589,19 +588,15 @@ async fn multiple_or_predicates() -> Result<()> { // factored out and appear only once in the following plan let 
expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #lineitem.l_partkey [l_partkey:Int64]", - " Projection: #part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_size >= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]", - " Filter: #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]", - " Inner Join: #lineitem.l_partkey = #part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]", - " Filter: #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2) [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", - " TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[#lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2)] 
[l_partkey:Int64, l_quantity:Decimal128(15, 2)]", - " Filter: #part.p_size >= Int32(1) AND #part.p_brand = Utf8(\"Brand#12\") AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #part.p_size <= Int32(15) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", - " TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[#part.p_size >= Int32(1), #part.p_brand = Utf8(\"Brand#12\") AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", " Projection: lineitem.l_partkey [l_partkey:Int64]", - " Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND CAST(part.p_size AS Int64) BETWEEN Int64(1) AND Int64(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND CAST(part.p_size AS Int64) BETWEEN Int64(1) AND Int64(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND CAST(part.p_size AS Int64) BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]", + " Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= 
Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]", " Inner Join: lineitem.l_partkey = part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]", - " TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", - " TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", + " Projection: lineitem.l_partkey, lineitem.l_quantity [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", + " Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND 
(lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= 
Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) [lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= 
Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean;N, l_partkey:Int64, l_quantity:Decimal128(15, 2)]", + " Projection: lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(3000),15,2) AS lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= 
Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(2000),15,2) AS lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity, lineitem.l_partkey, lineitem.l_quantity [lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= 
Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean;N, l_partkey:Int64, l_quantity:Decimal128(15, 2)]", + " TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR 
lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= 
Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", + " Filter: (part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", + " TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR 
part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 719c0c3d7a25..a1b22389faaa 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -162,8 +162,8 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#; \n Inner Join: supplier.s_nationkey = nation.n_nationkey\ \n Inner Join: partsupp.ps_suppkey = supplier.s_suppkey\ \n TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]\ - \n TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: nation projection=[n_nationkey, n_name, n_regionkey]\ + \n TableScan: supplier projection=[s_suppkey, s_nationkey]\ + \n TableScan: nation projection=[n_nationkey, n_regionkey]\ \n Filter: region.r_name = Utf8(\"EUROPE\")\ \n TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8(\"EUROPE\")]"; assert_eq!(actual, expected); @@ -445,7 +445,7 @@ order by value desc; .map_err(|e| format!("{:?} at {}", e, "error")) .unwrap(); let actual = format!("{}", plan.display_indent()); - let expected = "Sort: value DESC NULLS FIRST\ + let expected = "Sort: value DESC NULLS FIRST\ \n Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value\ \n Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15))\ \n CrossJoin:\ @@ -461,7 +461,7 @@ order by value desc; \n Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]\ \n Inner Join: supplier.s_nationkey = nation.n_nationkey\ \n Inner Join: partsupp.ps_suppkey = 
supplier.s_suppkey\ - \n TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost]\ + \n TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost]\ \n TableScan: supplier projection=[s_suppkey, s_nationkey]\ \n Filter: nation.n_name = Utf8(\"GERMANY\")\ \n TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8(\"GERMANY\")]"; diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 12603fe722bf..1bf7ad58bb95 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -25,10 +25,10 @@ pub mod eliminate_outer_join; pub mod filter_null_join_keys; pub mod inline_table_scan; pub mod optimizer; -pub mod projection_push_down; pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; +pub mod push_down_projection; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 1383a47b186a..dd4783ceb3b1 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -26,10 +26,10 @@ use crate::eliminate_limit::EliminateLimit; use crate::eliminate_outer_join::EliminateOuterJoin; use crate::filter_null_join_keys::FilterNullJoinKeys; use crate::inline_table_scan::InlineTableScan; -use crate::projection_push_down::ProjectionPushDown; use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; +use crate::push_down_projection::PushDownProjection; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; @@ -194,7 +194,7 @@ impl Optimizer { rules.push(Arc::new(SimplifyExpressions::new())); rules.push(Arc::new(UnwrapCastInComparison::new())); 
rules.push(Arc::new(CommonSubexprEliminate::new())); - rules.push(Arc::new(ProjectionPushDown::new())); + rules.push(Arc::new(PushDownProjection::new())); Self::with_rules(rules) } diff --git a/datafusion/optimizer/src/projection_push_down.rs b/datafusion/optimizer/src/push_down_projection.rs similarity index 97% rename from datafusion/optimizer/src/projection_push_down.rs rename to datafusion/optimizer/src/push_down_projection.rs index 1e54d7184ccc..3cedddc600d1 100644 --- a/datafusion/optimizer/src/projection_push_down.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -43,9 +43,9 @@ use std::{ /// Optimizer that removes unused projections and aggregations from plans /// This reduces both scans and #[derive(Default)] -pub struct ProjectionPushDown {} +pub struct PushDownProjection {} -impl OptimizerRule for ProjectionPushDown { +impl OptimizerRule for PushDownProjection { fn optimize( &self, plan: &LogicalPlan, @@ -62,11 +62,11 @@ impl OptimizerRule for ProjectionPushDown { } fn name(&self) -> &str { - "projection_push_down" + "push_down_projection" } } -impl ProjectionPushDown { +impl PushDownProjection { #[allow(missing_docs)] pub fn new() -> Self { Self {} @@ -75,7 +75,7 @@ impl ProjectionPushDown { /// Recursively transverses the logical plan removing expressions and that are not needed. fn optimize_plan( - _optimizer: &ProjectionPushDown, + _optimizer: &PushDownProjection, plan: &LogicalPlan, required_columns: &HashSet, // set of columns required up to this step has_projection: bool, @@ -94,23 +94,22 @@ fn optimize_plan( let mut new_expr = Vec::new(); let mut new_fields = Vec::new(); + // When we meet a projection, its exprs must contain all columns that its child needs. + // So we need to create an empty required_columns set instead of using the original new_required_columns. + // Otherwise it causes redundant columns. 
+ let mut new_required_columns = HashSet::new(); // Gather all columns needed for expressions in this Projection - schema - .fields() - .iter() - .enumerate() - .try_for_each(|(i, field)| { - if required_columns.contains(&field.qualified_column()) { - new_expr.push(expr[i].clone()); - new_fields.push(field.clone()); + schema.fields().iter().enumerate().for_each(|(i, field)| { + if required_columns.contains(&field.qualified_column()) { + new_expr.push(expr[i].clone()); + new_fields.push(field.clone()); + } + }); - // gather the new set of required columns - expr_to_columns(&expr[i], &mut new_required_columns) - } else { - Ok(()) - } - })?; + for e in new_expr.iter() { + expr_to_columns(e, &mut new_required_columns)? + } let new_input = optimize_plan( _optimizer, @@ -1012,7 +1011,7 @@ mod tests { } fn optimize(plan: &LogicalPlan) -> Result { - let rule = ProjectionPushDown::new(); + let rule = PushDownProjection::new(); rule.optimize(plan, &mut OptimizerConfig::new()) } }