Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Expr::InList to Substrait::RexType #6604

Merged
merged 7 commits into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub struct ExprSimplifier<S> {
info: S,
}

const THRESHOLD_INLINE_INLIST: usize = 3;
pub const THRESHOLD_INLINE_INLIST: usize = 3;

impl<S: SimplifyInfo> ExprSimplifier<S> {
/// Create a new `ExprSimplifier` with the given `info` such as an
Expand Down
86 changes: 85 additions & 1 deletion datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use substrait::proto::{
};
use substrait::proto::{FunctionArgument, SortField};

use datafusion::logical_expr::expr::Sort;
use datafusion::logical_expr::expr::{InList, Sort};
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
Expand All @@ -67,6 +67,8 @@ use crate::variation_const::{
/// The kind of function that a Substrait function name resolves to.
enum ScalarFunctionType {
/// A DataFusion built-in scalar function.
Builtin(BuiltinScalarFunction),
/// An operator, carried as a DataFusion `Operator`.
Op(Operator),
/// Logical negation; produced when the function name is `"not"`.
Not,
}

pub fn name_to_op(name: &str) -> Result<Operator> {
Expand Down Expand Up @@ -116,6 +118,20 @@ fn name_to_op_or_scalar_function(name: &str) -> Result<ScalarFunctionType> {
)))
}

/// Resolves `name` to either a built-in scalar function or logical negation.
///
/// The built-in lookup is tried first, so a name that parses as a
/// `BuiltinScalarFunction` always wins. Otherwise the exact name `"not"`
/// maps to `ScalarFunctionType::Not`.
///
/// # Errors
///
/// Returns `DataFusionError::NotImplemented` when `name` is neither a
/// built-in scalar function nor `"not"`.
fn scalar_function_or_not(name: &str) -> Result<ScalarFunctionType> {
    match BuiltinScalarFunction::from_str(name) {
        Ok(builtin) => Ok(ScalarFunctionType::Builtin(builtin)),
        // Only consulted after the built-in lookup has failed.
        Err(_) if name == "not" => Ok(ScalarFunctionType::Not),
        Err(_) => Err(DataFusionError::NotImplemented(format!(
            "Unsupported function name: {name:?}"
        ))),
    }
}

/// Convert Substrait Plan to DataFusion DataFrame
pub async fn from_substrait_plan(
ctx: &mut SessionContext,
Expand Down Expand Up @@ -660,6 +676,21 @@ pub async fn from_substrait_rex(
extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
match &e.rex_type {
Some(RexType::SingularOrList(s)) => {
let substrait_expr = s.value.as_ref().unwrap();
let substrait_list = s.options.as_ref();
Ok(Arc::new(Expr::InList(InList {
expr: Box::new(
from_substrait_rex(substrait_expr, input_schema, extensions)
.await?
.as_ref()
.clone(),
),
list: from_substrait_rex_vec(substrait_list, input_schema, extensions)
.await?,
negated: false,
})))
}
Some(RexType::Selection(field_ref)) => match &field_ref.reference_type {
Some(DirectReference(direct)) => match &direct.reference_type.as_ref() {
Some(StructField(x)) => match &x.child.as_ref() {
Expand Down Expand Up @@ -790,13 +821,66 @@ pub async fn from_substrait_rex(
],
})))
}
Ok(ScalarFunctionType::Not) => {
Err(DataFusionError::NotImplemented(
"Not expected function type: Not".to_string(),
))
}
Err(e) => Err(e),
}
}
(l, r) => Err(DataFusionError::NotImplemented(format!(
"Invalid arguments for binary expression: {l:?} and {r:?}"
))),
},
// ScalarFunction or Expr::Not
1 => {
let fun = match extensions.get(&f.function_reference) {
Some(fname) => scalar_function_or_not(fname),
None => Err(DataFusionError::NotImplemented(format!(
"Function not found: function reference = {:?}",
f.function_reference
))),
};

match fun {
Ok(ScalarFunctionType::Op(_)) => {
Err(DataFusionError::NotImplemented(
"Not expected function type: Op".to_string(),
))
}
Ok(scalar_function_type) => {
match &f.arguments.first().unwrap().arg_type {
Some(ArgType::Value(e)) => {
let expr =
from_substrait_rex(e, input_schema, extensions)
.await?
.as_ref()
.clone();
match scalar_function_type {
ScalarFunctionType::Builtin(fun) => Ok(Arc::new(
Expr::ScalarFunction(expr::ScalarFunction {
fun,
args: vec![expr],
}),
)),
ScalarFunctionType::Not => {
Ok(Arc::new(Expr::Not(Box::new(expr))))
}
_ => Err(DataFusionError::NotImplemented(
"Invalid arguments for Not expression"
.to_string(),
)),
}
}
_ => Err(DataFusionError::NotImplemented(
"Invalid arguments for Not expression".to_string(),
)),
}
}
Err(e) => Err(e),
}
}
// ScalarFunction
_ => {
let fun = match extensions.get(&f.function_reference) {
Expand Down
43 changes: 41 additions & 2 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ use datafusion::common::DFSchemaRef;
#[allow(unused_imports)]
use datafusion::logical_expr::aggregate_function;
use datafusion::logical_expr::expr::{
BinaryExpr, Case, Cast, ScalarFunction as DFScalarFunction, Sort, WindowFunction,
BinaryExpr, Case, Cast, InList, ScalarFunction as DFScalarFunction, Sort,
WindowFunction,
};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
use datafusion::prelude::Expr;
Expand All @@ -48,7 +49,7 @@ use substrait::{
window_function::bound::Kind as BoundKind,
window_function::Bound,
FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType,
ScalarFunction, WindowFunction as SubstraitWindowFunction,
ScalarFunction, SingularOrList, WindowFunction as SubstraitWindowFunction,
},
extensions::{
self,
Expand Down Expand Up @@ -614,6 +615,44 @@ pub fn to_substrait_rex(
),
) -> Result<Expression> {
match expr {
Expr::InList(InList {
expr,
list,
negated,
}) => {
let substrait_list = list
.iter()
.map(|x| to_substrait_rex(x, schema, col_ref_offset, extension_info))
.collect::<Result<Vec<Expression>>>()?;
let substrait_expr =
to_substrait_rex(expr, schema, col_ref_offset, extension_info)?;

let substrait_or_list = Expression {
rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList {
value: Some(Box::new(substrait_expr)),
options: substrait_list,
}))),
};

if *negated {
let function_anchor =
_register_function("not".to_string(), extension_info);

Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments: vec![FunctionArgument {
arg_type: Some(ArgType::Value(substrait_or_list)),
}],
output_type: None,
args: vec![],
options: vec![],
})),
})
} else {
Ok(substrait_or_list)
}
}
Expr::ScalarFunction(DFScalarFunction { fun, args }) => {
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
Expand Down
26 changes: 25 additions & 1 deletion datafusion/substrait/tests/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ mod tests {
use datafusion::execution::registry::SerializerRegistry;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode};
use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
use datafusion::prelude::*;
use substrait::proto::extensions::simple_extension_declaration::MappingType;

Expand Down Expand Up @@ -334,10 +335,32 @@ mod tests {
}

#[tokio::test]
async fn roundtrip_inlist() -> Result<()> {
async fn roundtrip_inlist_1() -> Result<()> {
roundtrip("SELECT * FROM data WHERE a IN (1, 2, 3)").await
}

#[tokio::test]
// String IN list whose length is <= datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST
async fn roundtrip_inlist_2() -> Result<()> {
    let sql = "SELECT * FROM data WHERE f IN ('a', 'b', 'c')";
    roundtrip(sql).await
}

#[tokio::test]
// String IN list whose length is > datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST
async fn roundtrip_inlist_3() -> Result<()> {
    // One more quoted literal than the threshold: '0', '1', ..., 'THRESHOLD'.
    let quoted: Vec<String> = (0..=THRESHOLD_INLINE_INLIST)
        .map(|i| format!("'{}'", i))
        .collect();
    let sql = format!("SELECT * FROM data WHERE f IN ({})", quoted.join(", "));

    roundtrip(&sql).await
}

#[tokio::test]
// Negated IN list (`NOT IN`) over four string literals.
async fn roundtrip_inlist_4() -> Result<()> {
    let sql = "SELECT * FROM data WHERE f NOT IN ('a', 'b', 'c', 'd')";
    roundtrip(sql).await
}

#[tokio::test]
async fn roundtrip_inner_join() -> Result<()> {
roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await
Expand Down Expand Up @@ -638,6 +661,7 @@ mod tests {
Field::new("c", DataType::Date32, true),
Field::new("d", DataType::Boolean, true),
Field::new("e", DataType::UInt32, true),
Field::new("f", DataType::Utf8, true),
]);
explicit_options.schema = Some(&schema);
ctx.register_csv("data", "tests/testdata/data.csv", explicit_options)
Expand Down
6 changes: 3 additions & 3 deletions datafusion/substrait/tests/testdata/data.csv
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
a,b,c,d,e
1,2.0,2020-01-01,false,4294967296
3,4.5,2020-01-01,true,2147483648
a,b,c,d,e,f
1,2.0,2020-01-01,false,4294967296,'a'
3,4.5,2020-01-01,true,2147483648,'b'