diff --git a/benchmarks/expected-plans/q12.txt b/benchmarks/expected-plans/q12.txt index ef3bab4b7096c..547d12e9f19a2 100644 --- a/benchmarks/expected-plans/q12.txt +++ b/benchmarks/expected-plans/q12.txt @@ -2,6 +2,6 @@ Sort: lineitem.l_shipmode ASC NULLS LAST Projection: lineitem.l_shipmode, SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS high_line_count, SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS low_line_count Aggregate: groupBy=[[lineitem.l_shipmode]], aggr=[[SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]] Inner Join: lineitem.l_orderkey = orders.o_orderkey - Filter: lineitem.l_shipmode IN ([Utf8("MAIL"), Utf8("SHIP")]) AND lineitem.l_commitdate < lineitem.l_receiptdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("8766") AND lineitem.l_receiptdate < Date32("9131") + Filter: (lineitem.l_shipmode = Utf8("SHIP") OR lineitem.l_shipmode = Utf8("MAIL")) AND lineitem.l_commitdate < lineitem.l_receiptdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("8766") AND lineitem.l_receiptdate < Date32("9131") TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode] TableScan: orders projection=[o_orderkey, o_orderpriority] \ No newline at end of file diff --git a/benchmarks/expected-plans/q19.txt b/benchmarks/expected-plans/q19.txt index 552d743917dc2..3efc3718d01a7 100644 --- a/benchmarks/expected-plans/q19.txt +++ b/benchmarks/expected-plans/q19.txt @@ -3,7 +3,7 @@ Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS re Projection: 
lineitem.l_extendedprice, lineitem.l_discount Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) Inner Join: lineitem.l_partkey = part.p_partkey - Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR REG")]) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON") + Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON") TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode] Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container 
IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) TableScan: part projection=[p_partkey, p_brand, p_size, p_container] diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 17e09497fafe1..a0fbba2f84fb3 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -2040,7 +2040,9 @@ mod tests { .build()?; let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. - let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"1\") }], negated: false }"; + + let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } } }"; + let actual = format!("{:?}", execution_plan); assert!(actual.contains(expected), "{}", actual); diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 6a3723644abb6..383fb05c5d00e 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -40,6 +40,8 @@ pub struct ExprSimplifier { info: S, } +const THRESHOLD_INLINE_INLIST: usize = 3; + impl ExprSimplifier { /// Create a new `ExprSimplifier` with the given 
`info` such as an /// instance of [`SimplifyContext`]. See @@ -365,7 +367,48 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { None => lit_bool_null(), } } + // expr IN () --> false + // expr NOT IN () --> true + Expr::InList { + expr, + list, + negated, + } if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => { + lit(negated) + } + // if expr is a single column reference: + // expr IN (A, B, ...) --> (expr = A) OR (expr = B) OR ... + Expr::InList { + expr, + list, + negated, + } if !list.is_empty() + && ( + // For lists with only 1 value we allow more complex expressions to be simplified + // e.g. SUBSTR(c1, 2, 3) IN ('1') -> SUBSTR(c1, 2, 3) = '1' + // for more than one we avoid repeating this potentially expensive + // expression + list.len() == 1 + || list.len() <= THRESHOLD_INLINE_INLIST + && expr.try_into_col().is_ok() + ) => + { + let first_val = list[0].clone(); + if negated { + list.into_iter() + .skip(1) + .fold((*expr.clone()).not_eq(first_val), |acc, y| { + (*expr.clone()).not_eq(y).and(acc) + }) + } else { + list.into_iter() + .skip(1) + .fold((*expr.clone()).eq(first_val), |acc, y| { + (*expr.clone()).eq(y).or(acc) + }) + } + } // // Rules for NotEq // @@ -1749,6 +1792,37 @@ mod tests { assert_eq!(expected_expr, result); } + #[test] + fn simplify_inlist() { + assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false)); + assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true)); + + assert_eq!( + simplify(in_list(col("c1"), vec![lit(1)], false)), + col("c1").eq(lit(1)) + ); + assert_eq!( + simplify(in_list(col("c1"), vec![lit(1)], true)), + col("c1").not_eq(lit(1)) + ); + + // more complex expressions can be simplified if list contains + // one element only + assert_eq!( + simplify(in_list(col("c1") * lit(10), vec![lit(2)], false)), + (col("c1") * lit(10)).eq(lit(2)) + ); + + assert_eq!( + simplify(in_list(col("c1"), vec![lit(1), lit(2)], false)), + col("c1").eq(lit(2)).or(col("c1").eq(lit(1))) + );
+ assert_eq!( + simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)), + col("c1").not_eq(lit(2)).and(col("c1").not_eq(lit(1))) + ); + } + #[test] fn simplify_expr_bool_and() { // col & true is always col diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index af0da6b416b21..3fda5817f98b2 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -706,7 +706,8 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d NOT IN ([Int32(1), Int32(2), Int32(3)])\ + let expected = + "Filter: test.d != Int32(3) AND test.d != Int32(2) AND test.d != Int32(1)\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -721,7 +722,8 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d IN ([Int32(1), Int32(2), Int32(3)])\ + let expected = + "Filter: test.d = Int32(3) OR test.d = Int32(2) OR test.d = Int32(1)\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected);