From a8d0989fbc9a20207270d315c14d9ffa14bc89e6 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 9 Jun 2023 11:02:59 +0800 Subject: [PATCH 1/6] Expr::InList to Substrait::RexType Signed-off-by: jayzhan211 --- .../substrait/src/logical_plan/producer.rs | 44 ++++++++++++++++++- .../substrait/tests/roundtrip_logical_plan.rs | 13 ++++++ datafusion/substrait/tests/testdata/data.csv | 6 +-- 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 228341548813..0ffcc97301f7 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -29,7 +29,8 @@ use datafusion::common::DFSchemaRef; #[allow(unused_imports)] use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ - BinaryExpr, Case, Cast, ScalarFunction as DFScalarFunction, Sort, WindowFunction, + BinaryExpr, Case, Cast, InList, ScalarFunction as DFScalarFunction, Sort, + WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -614,6 +615,47 @@ pub fn to_substrait_rex( ), ) -> Result { match expr { + Expr::InList(InList { + expr, + list, + negated, + }) => { + // expr IN (A, B, ...) --> (expr = A) OR (expr = B) OR (expr = C) + // negated: expr NOT IN (A, B, ...) --> (expr != A) AND (expr != B) AND (expr != C) + let op_for_list = match negated { + true => Operator::And, + false => Operator::Or, + }; + let op_for_list_item = match negated { + true => Operator::NotEq, + false => Operator::Eq, + }; + + let substrait_list = list + .iter() + .map(|x| to_substrait_rex(x, schema, extension_info)) + .collect::>>()?; + let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; + + let init_val = make_binary_op_scalar_func( + &substrait_expr, + &substrait_list[0], + op_for_list_item, + extension_info, + ); + + let res = substrait_list.into_iter().skip(1).fold(init_val, |acc, y| { + let val = make_binary_op_scalar_func( + &substrait_expr, + &y, + op_for_list_item, + extension_info, + ); + + make_binary_op_scalar_func(&acc, &val, op_for_list, extension_info) + }); + Ok(res) + } Expr::ScalarFunction(DFScalarFunction { fun, args }) => { let mut arguments: Vec = vec![]; for arg in args { diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs index e209ebedc0f3..5ebfcb8f4c73 100644 --- a/datafusion/substrait/tests/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs @@ -278,6 +278,18 @@ mod tests { .await } + #[tokio::test] + // Test with length <= datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST + async fn roundtrip_inlist_1() -> Result<()> { + roundtrip("SELECT * FROM data WHERE f IN ('a', 'b', 'c')").await + } + + #[tokio::test] + // Test with length > datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST + async fn roundtrip_inlist_2() -> Result<()> { + roundtrip("SELECT * FROM data WHERE f IN ('a', 'b', 'c', 'd')").await + } + #[tokio::test] async fn simple_scalar_function_abs() -> Result<()> { roundtrip("SELECT ABS(a) FROM data").await @@ -638,6 +650,7 @@ mod tests { Field::new("c", DataType::Date32, true), Field::new("d", DataType::Boolean, true), Field::new("e", DataType::UInt32, true), + Field::new("f", DataType::Utf8, true), ]); explicit_options.schema = Some(&schema); ctx.register_csv("data", "tests/testdata/data.csv", explicit_options) diff --git a/datafusion/substrait/tests/testdata/data.csv b/datafusion/substrait/tests/testdata/data.csv index 170457da5812..1b85b166b1df 100644 --- a/datafusion/substrait/tests/testdata/data.csv +++ b/datafusion/substrait/tests/testdata/data.csv @@ -1,3 +1,3 @@ -a,b,c,d,e -1,2.0,2020-01-01,false,4294967296 -3,4.5,2020-01-01,true,2147483648 \ No newline at end of file +a,b,c,d,e,f +1,2.0,2020-01-01,false,4294967296,'a' +3,4.5,2020-01-01,true,2147483648,'b' \ No newline at end of file From 784ea2af3a11ade24b1c94cab7124e09eea77171 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 10 Jun 2023 08:35:57 +0800 Subject: [PATCH 2/6] check empty list Signed-off-by: jayzhan211 --- datafusion/substrait/src/logical_plan/producer.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 0ffcc97301f7..a98ae7d49643 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -637,9 +637,16 @@ pub fn to_substrait_rex( .collect::>>()?; let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; + if substrait_list.is_empty() { + return Err(DataFusionError::Internal( + "Empty list in IN expression".to_string(), + )); + } + + let first_val = substrait_list.first().unwrap(); let init_val = make_binary_op_scalar_func( &substrait_expr, - &substrait_list[0], + first_val, op_for_list_item, extension_info, ); From 3d8ac7ca2de9cd56fbe9ac67761f45839cf1a7eb Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 12 Jun 2023 19:51:48 +0800 Subject: [PATCH 3/6] Utilize SingularOrList Signed-off-by: jayzhan211 --- .../substrait/src/logical_plan/producer.rs | 42 ++++--------------- 1 file changed, 7 insertions(+), 35 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index a98ae7d49643..6817198779c9 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -49,7 +49,7 @@ use substrait::{ window_function::bound::Kind as BoundKind, window_function::Bound, FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, - ScalarFunction, WindowFunction as SubstraitWindowFunction, + ScalarFunction, SingularOrList, WindowFunction as SubstraitWindowFunction, }, extensions::{ self, @@ -622,46 +622,18 @@ pub fn to_substrait_rex( }) => { // expr IN (A, B, ...) --> (expr = A) OR (expr = B) OR (expr = C) // negated: expr NOT IN (A, B, ...) --> (expr != A) AND (expr != B) AND (expr != C) - let op_for_list = match negated { - true => Operator::And, - false => Operator::Or, - }; - let op_for_list_item = match negated { - true => Operator::NotEq, - false => Operator::Eq, - }; - let substrait_list = list .iter() .map(|x| to_substrait_rex(x, schema, extension_info)) .collect::>>()?; let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; - if substrait_list.is_empty() { - return Err(DataFusionError::Internal( - "Empty list in IN expression".to_string(), - )); - } - - let first_val = substrait_list.first().unwrap(); - let init_val = make_binary_op_scalar_func( - &substrait_expr, - first_val, - op_for_list_item, - extension_info, - ); - - let res = substrait_list.into_iter().skip(1).fold(init_val, |acc, y| { - let val = make_binary_op_scalar_func( - &substrait_expr, - &y, - op_for_list_item, - extension_info, - ); - - make_binary_op_scalar_func(&acc, &val, op_for_list, extension_info) - }); - Ok(res) + Ok(Expression { + rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { + value: Some(Box::new(substrait_expr)), + options: substrait_list, + }))), + }) } Expr::ScalarFunction(DFScalarFunction { fun, args }) => { let mut arguments: Vec = vec![]; From 536fc7ba8951be6b4ffb75238603c668cd30cb31 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 13 Jun 2023 15:00:51 +0800 Subject: [PATCH 4/6] Support NOT IN expr Signed-off-by: jayzhan211 --- .../substrait/src/logical_plan/consumer.rs | 86 ++++++++++++++++++- .../substrait/src/logical_plan/producer.rs | 25 +++++- .../substrait/tests/roundtrip_logical_plan.rs | 31 ++++--- 3 files changed, 124 insertions(+), 18 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f15ffdf42374..4af0c4c6c9ec 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -52,7 +52,7 @@ use substrait::proto::{ }; use substrait::proto::{FunctionArgument, SortField}; -use datafusion::logical_expr::expr::Sort; +use datafusion::logical_expr::expr::{InList, Sort}; use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; @@ -67,6 +67,8 @@ use crate::variation_const::{ enum ScalarFunctionType { Builtin(BuiltinScalarFunction), Op(Operator), + // logical negation + Not, } pub fn name_to_op(name: &str) -> Result { @@ -116,6 +118,20 @@ fn name_to_op_or_scalar_function(name: &str) -> Result { ))) } +fn scalar_function_or_not(name: &str) -> Result { + if let Ok(fun) = BuiltinScalarFunction::from_str(name) { + return Ok(ScalarFunctionType::Builtin(fun)); + } + + if name == "not" { + return Ok(ScalarFunctionType::Not); + } + + Err(DataFusionError::NotImplemented(format!( + "Unsupported function name: {name:?}" + ))) +} + /// Convert Substrait Plan to DataFusion DataFrame pub async fn from_substrait_plan( ctx: &mut SessionContext, @@ -660,6 +676,21 @@ pub async fn from_substrait_rex( extensions: &HashMap, ) -> Result> { match &e.rex_type { + Some(RexType::SingularOrList(s)) => { + let substrait_expr = s.value.as_ref().unwrap(); + let substrait_list = s.options.as_ref(); + Ok(Arc::new(Expr::InList(InList { + expr: Box::new( + from_substrait_rex(substrait_expr, input_schema, extensions) + .await? + .as_ref() + .clone(), + ), + list: from_substrait_rex_vec(substrait_list, input_schema, extensions) + .await?, + negated: false, + }))) + } Some(RexType::Selection(field_ref)) => match &field_ref.reference_type { Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { Some(StructField(x)) => match &x.child.as_ref() { @@ -790,6 +821,11 @@ pub async fn from_substrait_rex( ], }))) } + Ok(ScalarFunctionType::Not) => { + Err(DataFusionError::NotImplemented( + "Not expected function type: Not".to_string(), + )) + } Err(e) => Err(e), } } @@ -797,6 +833,54 @@ pub async fn from_substrait_rex( "Invalid arguments for binary expression: {l:?} and {r:?}" ))), }, + // ScalarFunction or Expr::Not + 1 => { + let fun = match extensions.get(&f.function_reference) { + Some(fname) => scalar_function_or_not(fname), + None => Err(DataFusionError::NotImplemented(format!( + "Function not found: function reference = {:?}", + f.function_reference + ))), + }; + + match fun { + Ok(ScalarFunctionType::Op(_)) => { + Err(DataFusionError::NotImplemented( + "Not expected function type: Op".to_string(), + )) + } + Ok(scalar_function_type) => { + match &f.arguments.first().unwrap().arg_type { + Some(ArgType::Value(e)) => { + let expr = + from_substrait_rex(e, input_schema, extensions) + .await? + .as_ref() + .clone(); + match scalar_function_type { + ScalarFunctionType::Builtin(fun) => Ok(Arc::new( + Expr::ScalarFunction(expr::ScalarFunction { + fun, + args: vec![expr], + }), + )), + ScalarFunctionType::Not => { + Ok(Arc::new(Expr::Not(Box::new(expr)))) + } + _ => Err(DataFusionError::NotImplemented( + "Invalid arguments for Not expression" + .to_string(), + )), + } + } + _ => Err(DataFusionError::NotImplemented( + "Invalid arguments for Not expression".to_string(), + )), + } + } + Err(e) => Err(e), + } + } // ScalarFunction _ => { let fun = match extensions.get(&f.function_reference) { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 6817198779c9..541e9aa94a68 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -620,20 +620,37 @@ pub fn to_substrait_rex( list, negated, }) => { - // expr IN (A, B, ...) --> (expr = A) OR (expr = B) OR (expr = C) - // negated: expr NOT IN (A, B, ...) --> (expr != A) AND (expr != B) AND (expr != C) let substrait_list = list .iter() .map(|x| to_substrait_rex(x, schema, extension_info)) .collect::>>()?; let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; - Ok(Expression { + let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { value: Some(Box::new(substrait_expr)), options: substrait_list, }))), - }) + }; + + if *negated { + let function_anchor = + _register_function("not".to_string(), extension_info); + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_or_list)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_or_list) + } } Expr::ScalarFunction(DFScalarFunction { fun, args }) => { let mut arguments: Vec = vec![]; diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs index 5ebfcb8f4c73..5cd9e9e4b91d 100644 --- a/datafusion/substrait/tests/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs @@ -278,18 +278,6 @@ mod tests { .await } - #[tokio::test] - // Test with length <= datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST - async fn roundtrip_inlist_1() -> Result<()> { - roundtrip("SELECT * FROM data WHERE f IN ('a', 'b', 'c')").await - } - - #[tokio::test] - // Test with length > datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST - async fn roundtrip_inlist_2() -> Result<()> { - roundtrip("SELECT * FROM data WHERE f IN ('a', 'b', 'c', 'd')").await - } - #[tokio::test] async fn simple_scalar_function_abs() -> Result<()> { roundtrip("SELECT ABS(a) FROM data").await @@ -346,10 +334,27 @@ mod tests { } #[tokio::test] - async fn roundtrip_inlist() -> Result<()> { + async fn roundtrip_inlist_1() -> Result<()> { roundtrip("SELECT * FROM data WHERE a IN (1, 2, 3)").await } + #[tokio::test] + // Test with length <= datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST + async fn roundtrip_inlist_2() -> Result<()> { + roundtrip("SELECT * FROM data WHERE f IN ('a', 'b', 'c')").await + } + + #[tokio::test] + // Test with length > datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST + async fn roundtrip_inlist_3() -> Result<()> { + roundtrip("SELECT * FROM data WHERE f IN ('a', 'b', 'c', 'd')").await + } + + #[tokio::test] + async fn roundtrip_inlist_4() -> Result<()> { + roundtrip("SELECT * FROM data WHERE f NOT IN ('a', 'b', 'c', 'd')").await + } + #[tokio::test] async fn roundtrip_inner_join() -> Result<()> { roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await From dd2a6c2354a71b46d39e3621042ff56e3fc150db Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 13 Jun 2023 17:14:26 +0800 Subject: [PATCH 5/6] address conflict Signed-off-by: jayzhan211 --- datafusion/substrait/src/logical_plan/producer.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 541e9aa94a68..35b8b8e5e372 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -622,9 +622,10 @@ pub fn to_substrait_rex( }) => { let substrait_list = list .iter() - .map(|x| to_substrait_rex(x, schema, extension_info)) + .map(|x| to_substrait_rex(x, schema, col_ref_offset, extension_info)) .collect::>>()?; - let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; + let substrait_expr = + to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { From ca1f318acaf22ddb02910086c6fdecfba402a350 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 14 Jun 2023 09:19:18 +0800 Subject: [PATCH 6/6] address comment Signed-off-by: jayzhan211 --- .../optimizer/src/simplify_expressions/expr_simplifier.rs | 2 +- datafusion/substrait/tests/roundtrip_logical_plan.rs | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 8aebae18c1ae..c15f0d91a50e 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -46,7 +46,7 @@ pub struct ExprSimplifier { info: S, } -const THRESHOLD_INLINE_INLIST: usize = 3; +pub const THRESHOLD_INLINE_INLIST: usize = 3; impl ExprSimplifier { /// Create a new `ExprSimplifier` with the given `info` such as an diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs index 5cd9e9e4b91d..5054042501a5 100644 --- a/datafusion/substrait/tests/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs @@ -31,6 +31,7 @@ mod tests { use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode}; + use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; use substrait::proto::extensions::simple_extension_declaration::MappingType; @@ -347,7 +348,12 @@ mod tests { #[tokio::test] // Test with length > datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST async fn roundtrip_inlist_3() -> Result<()> { - roundtrip("SELECT * FROM data WHERE f IN ('a', 'b', 'c', 'd')").await + let inlist = (0..THRESHOLD_INLINE_INLIST + 1) + .map(|i| format!("'{}'", i)) + .collect::>() + .join(", "); + + roundtrip(&format!("SELECT * FROM data WHERE f IN ({})", inlist)).await } #[tokio::test]