Skip to content

Commit

Permalink
Support NOT IN expr
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Jun 13, 2023
1 parent dcb5d38 commit 4192e31
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 18 deletions.
86 changes: 85 additions & 1 deletion datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use substrait::proto::{
};
use substrait::proto::{FunctionArgument, SortField};

use datafusion::logical_expr::expr::Sort;
use datafusion::logical_expr::expr::{InList, Sort};
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
Expand All @@ -67,6 +67,8 @@ use crate::variation_const::{
enum ScalarFunctionType {
Builtin(BuiltinScalarFunction),
Op(Operator),
// logical negation
Not,
}

pub fn name_to_op(name: &str) -> Result<Operator> {
Expand Down Expand Up @@ -116,6 +118,20 @@ fn name_to_op_or_scalar_function(name: &str) -> Result<ScalarFunctionType> {
)))
}

fn scalar_function_or_not(name: &str) -> Result<ScalarFunctionType> {
if let Ok(fun) = BuiltinScalarFunction::from_str(name) {
return Ok(ScalarFunctionType::Builtin(fun));
}

if name == "not" {
return Ok(ScalarFunctionType::Not);
}

Err(DataFusionError::NotImplemented(format!(
"Unsupported function name: {name:?}"
)))
}

/// Convert Substrait Plan to DataFusion DataFrame
pub async fn from_substrait_plan(
ctx: &mut SessionContext,
Expand Down Expand Up @@ -660,6 +676,21 @@ pub async fn from_substrait_rex(
extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
match &e.rex_type {
Some(RexType::SingularOrList(s)) => {
let substrait_expr = s.value.as_ref().unwrap();
let substrait_list = s.options.as_ref();
Ok(Arc::new(Expr::InList(InList {
expr: Box::new(
from_substrait_rex(substrait_expr, input_schema, extensions)
.await?
.as_ref()
.clone(),
),
list: from_substrait_rex_vec(substrait_list, input_schema, extensions)
.await?,
negated: false,
})))
}
Some(RexType::Selection(field_ref)) => match &field_ref.reference_type {
Some(DirectReference(direct)) => match &direct.reference_type.as_ref() {
Some(StructField(x)) => match &x.child.as_ref() {
Expand Down Expand Up @@ -790,13 +821,66 @@ pub async fn from_substrait_rex(
],
})))
}
Ok(ScalarFunctionType::Not) => {
Err(DataFusionError::NotImplemented(
"Not expected function type: Not".to_string(),
))
}
Err(e) => Err(e),
}
}
(l, r) => Err(DataFusionError::NotImplemented(format!(
"Invalid arguments for binary expression: {l:?} and {r:?}"
))),
},
// ScalarFunction or Expr::Not
1 => {
let fun = match extensions.get(&f.function_reference) {
Some(fname) => scalar_function_or_not(fname),
None => Err(DataFusionError::NotImplemented(format!(
"Function not found: function reference = {:?}",
f.function_reference
))),
};

match fun {
Ok(ScalarFunctionType::Op(_)) => {
Err(DataFusionError::NotImplemented(
"Not expected function type: Op".to_string(),
))
}
Ok(scalar_function_type) => {
match &f.arguments.first().unwrap().arg_type {
Some(ArgType::Value(e)) => {
let expr =
from_substrait_rex(e, input_schema, extensions)
.await?
.as_ref()
.clone();
match scalar_function_type {
ScalarFunctionType::Builtin(fun) => Ok(Arc::new(
Expr::ScalarFunction(expr::ScalarFunction {
fun,
args: vec![expr],
}),
)),
ScalarFunctionType::Not => {
Ok(Arc::new(Expr::Not(Box::new(expr))))
}
_ => Err(DataFusionError::NotImplemented(
"Invalid arguments for Not expression"
.to_string(),
)),
}
}
_ => Err(DataFusionError::NotImplemented(
"Invalid arguments for Not expression".to_string(),
)),
}
}
Err(e) => Err(e),
}
}
// ScalarFunction
_ => {
let fun = match extensions.get(&f.function_reference) {
Expand Down
25 changes: 21 additions & 4 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -582,20 +582,37 @@ pub fn to_substrait_rex(
list,
negated,
}) => {
// expr IN (A, B, ...) --> (expr = A) OR (expr = B) OR (expr = C)
// negated: expr NOT IN (A, B, ...) --> (expr != A) AND (expr != B) AND (expr != C)
let substrait_list = list
.iter()
.map(|x| to_substrait_rex(x, schema, extension_info))
.collect::<Result<Vec<Expression>>>()?;
let substrait_expr = to_substrait_rex(expr, schema, extension_info)?;

Ok(Expression {
let substrait_or_list = Expression {
rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList {
value: Some(Box::new(substrait_expr)),
options: substrait_list,
}))),
})
};

if *negated {
let function_anchor =
_register_function("not".to_string(), extension_info);

Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments: vec![FunctionArgument {
arg_type: Some(ArgType::Value(substrait_or_list)),
}],
output_type: None,
args: vec![],
options: vec![],
})),
})
} else {
Ok(substrait_or_list)
}
}
Expr::ScalarFunction(DFScalarFunction { fun, args }) => {
let mut arguments: Vec<FunctionArgument> = vec![];
Expand Down
31 changes: 18 additions & 13 deletions datafusion/substrait/tests/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,18 +278,6 @@ mod tests {
.await
}

#[tokio::test]
// Test with length <= datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST
async fn roundtrip_inlist_1() -> Result<()> {
roundtrip("SELECT * FROM data WHERE f IN ('a', 'b', 'c')").await
}

#[tokio::test]
// Test with length > datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST
async fn roundtrip_inlist_2() -> Result<()> {
roundtrip("SELECT * FROM data WHERE f IN ('a', 'b', 'c', 'd')").await
}

#[tokio::test]
async fn simple_scalar_function_abs() -> Result<()> {
roundtrip("SELECT ABS(a) FROM data").await
Expand Down Expand Up @@ -346,10 +334,27 @@ mod tests {
}

#[tokio::test]
async fn roundtrip_inlist() -> Result<()> {
async fn roundtrip_inlist_1() -> Result<()> {
roundtrip("SELECT * FROM data WHERE a IN (1, 2, 3)").await
}

#[tokio::test]
// Test with length <= datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST
async fn roundtrip_inlist_2() -> Result<()> {
roundtrip("SELECT * FROM data WHERE f IN ('a', 'b', 'c')").await
}

#[tokio::test]
// Test with length > datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST
async fn roundtrip_inlist_3() -> Result<()> {
roundtrip("SELECT * FROM data WHERE f IN ('a', 'b', 'c', 'd')").await
}

#[tokio::test]
async fn roundtrip_inlist_4() -> Result<()> {
roundtrip("SELECT * FROM data WHERE f NOT IN ('a', 'b', 'c', 'd')").await
}

#[tokio::test]
async fn roundtrip_inner_join() -> Result<()> {
roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await
Expand Down

0 comments on commit 4192e31

Please sign in to comment.