diff --git a/src/consumer.rs b/src/consumer.rs index d1b9e57..79c8ede 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -1,13 +1,13 @@ use async_recursion::async_recursion; use datafusion::common::{DFField, DFSchema, DFSchemaRef}; -use datafusion::logical_expr::{aggregate_function, LogicalPlan}; +use datafusion::logical_expr::{aggregate_function, LogicalPlan, LogicalPlanBuilder}; use datafusion::logical_plan::build_join_schema; use datafusion::prelude::JoinType; use datafusion::{ error::{DataFusionError, Result}, logical_plan::{Expr, Operator}, optimizer::utils::split_conjunction, - prelude::{Column, DataFrame, SessionContext}, + prelude::{Column, SessionContext}, scalar::ScalarValue, }; @@ -67,7 +67,7 @@ pub fn name_to_op(name: &str) -> Result { } /// Convert Substrait Plan to DataFusion DataFrame -pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Result> { +pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Result { // Register function extension let function_extension = plan .extensions @@ -113,17 +113,18 @@ pub async fn from_substrait_rel( ctx: &mut SessionContext, rel: &Rel, extensions: &HashMap, -) -> Result> { +) -> Result { match &rel.rel_type { Some(RelType::Project(p)) => { if let Some(input) = p.input.as_ref() { - let input = from_substrait_rel(ctx, input, extensions).await?; + let input = + LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?); let mut exprs: Vec = vec![]; for e in &p.expressions { let x = from_substrait_rex(e, &input.schema(), extensions).await?; exprs.push(x.as_ref().clone()); } - input.select(exprs) + input.project(exprs)?.build() } else { Err(DataFusionError::NotImplemented( "Projection without an input is not supported".to_string(), @@ -132,10 +133,11 @@ pub async fn from_substrait_rel( } Some(RelType::Filter(filter)) => { if let Some(input) = filter.input.as_ref() { - let input = from_substrait_rel(ctx, input, extensions).await?; + let input = + LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?); if let Some(condition) = filter.condition.as_ref() { let expr = from_substrait_rex(condition, &input.schema(), extensions).await?; - input.filter(expr.as_ref().clone()) + input.filter(expr.as_ref().clone())?.build() } else { Err(DataFusionError::NotImplemented( "Filter without an condition is not valid".to_string(), @@ -149,10 +151,11 @@ pub async fn from_substrait_rel( } Some(RelType::Fetch(fetch)) => { if let Some(input) = fetch.input.as_ref() { - let input = from_substrait_rel(ctx, input, extensions).await?; + let input = + LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?); let offset = fetch.offset as usize; let count = fetch.count as usize; - input.limit(offset, Some(count)) + input.limit(offset, Some(count))?.build() } else { Err(DataFusionError::NotImplemented( "Fetch without an input is not valid".to_string(), @@ -161,7 +164,8 @@ pub async fn from_substrait_rel( } Some(RelType::Sort(sort)) => { if let Some(input) = sort.input.as_ref() { - let input = from_substrait_rel(ctx, input, extensions).await?; + let input = + LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?); let mut sorts: Vec = vec![]; for s in &sort.sorts { let expr = @@ -205,7 +209,7 @@ pub async fn from_substrait_rel( nulls_first: nulls_first, }); } - input.sort(sorts) + input.sort(sorts)?.build() } else { Err(DataFusionError::NotImplemented( "Sort without an input is not valid".to_string(), @@ -214,7 +218,8 @@ pub async fn from_substrait_rel( } Some(RelType::Aggregate(agg)) => { if let Some(input) = agg.input.as_ref() { - let input = from_substrait_rel(ctx, input, extensions).await?; + let input = + LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?); let mut group_expr = vec![]; let mut aggr_expr = vec![]; @@ -263,7 +268,7 @@ pub async fn from_substrait_rel( aggr_expr.push(agg_func?.as_ref().clone()); } - input.aggregate(group_expr, aggr_expr) + input.aggregate(group_expr, aggr_expr)?.build() } else { Err(DataFusionError::NotImplemented( "Aggregate without an input is not valid".to_string(), @@ -271,7 +276,9 @@ pub async fn from_substrait_rel( } } Some(RelType::Join(join)) => { - let left = from_substrait_rel(ctx, &join.left.as_ref().unwrap(), extensions).await?; + let left = LogicalPlanBuilder::from( + from_substrait_rel(ctx, &join.left.as_ref().unwrap(), extensions).await?, + ); let right = from_substrait_rel(ctx, &join.right.as_ref().unwrap(), extensions).await?; let join_type = match join.r#type { 1 => JoinType::Inner, @@ -287,7 +294,7 @@ pub async fn from_substrait_rel( let on = from_substrait_rex(&join.expression.as_ref().unwrap(), &schema, extensions).await?; split_conjunction(&on, &mut predicates); - let pairs = predicates + let pairs: Vec<(Column, Column)> = predicates .iter() .map(|p| match p { Expr::BinaryExpr { @@ -295,7 +302,7 @@ pub async fn from_substrait_rel( op: Operator::Eq, right, } => match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => Ok((l.flat_name(), r.flat_name())), + (Expr::Column(l), Expr::Column(r)) => Ok((l.clone(), r.clone())), _ => { return Err(DataFusionError::Internal( "invalid join condition".to_string(), @@ -309,9 +316,9 @@ pub async fn from_substrait_rel( } }) .collect::>>()?; - let left_cols: Vec<&str> = pairs.iter().map(|(l, _)| l.as_str()).collect(); - let right_cols: Vec<&str> = pairs.iter().map(|(_, r)| r.as_str()).collect(); - left.join(right, join_type, &left_cols, &right_cols, None) + let (left_cols, right_cols): (Vec<_>, Vec<_>) = pairs.iter().cloned().unzip(); + left.join(&right, join_type, (left_cols, right_cols), None)? + .build() } Some(RelType::Read(read)) => match &read.as_ref().read_type { Some(ReadType::NamedTable(nt)) => { @@ -334,7 +341,7 @@ pub async fn from_substrait_rel( table: &nt.names[2], }, }; - let t = ctx.table(table_reference)?; + let t = ctx.table(table_reference)?.to_logical_plan()?; match &read.projection { Some(MaskExpression { select, .. }) => match &select.as_ref() { Some(projection) => { @@ -343,7 +350,7 @@ pub async fn from_substrait_rel( .iter() .map(|item| item.field as usize) .collect(); - match t.to_logical_plan()? { + match &t { LogicalPlan::TableScan(scan) => { let mut scan = scan.clone(); let fields: Vec = column_indices @@ -354,8 +361,7 @@ pub async fn from_substrait_rel( scan.projected_schema = DFSchemaRef::new( DFSchema::new_with_metadata(fields, HashMap::new())?, ); - let plan = LogicalPlan::TableScan(scan); - Ok(Arc::new(DataFrame::new(ctx.state.clone(), &plan))) + Ok(LogicalPlan::TableScan(scan)) } _ => Err(DataFusionError::Internal( "unexpected plan for table".to_string(), diff --git a/tests/roundtrip.rs b/tests/roundtrip.rs index eb6d967..5956cac 100644 --- a/tests/roundtrip.rs +++ b/tests/roundtrip.rs @@ -162,8 +162,7 @@ mod tests { let df = ctx.sql(sql).await?; let plan = df.to_logical_plan()?; let proto = to_substrait_plan(&plan)?; - let df = from_substrait_plan(&mut ctx, &proto).await?; - let plan2 = df.to_logical_plan()?; + let plan2 = from_substrait_plan(&mut ctx, &proto).await?; let plan2str = format!("{:?}", plan2); assert_eq!(expected_plan_str, &plan2str); Ok(()) @@ -174,9 +173,8 @@ mod tests { let df = ctx.sql(sql).await?; let plan1 = df.to_logical_plan()?; let proto = to_substrait_plan(&plan1)?; - - let df = from_substrait_plan(&mut ctx, &proto).await?; - let plan2 = df.to_logical_plan()?; + let plan2 = from_substrait_plan(&mut ctx, &proto).await?; + let plan2 = ctx.optimize(&plan2)?; // Format plan string and replace all None's with 0 let plan1str = format!("{:?}", plan1).replace("None", "0"); @@ -194,15 +192,11 @@ mod tests { let df_a = ctx.sql(sql_with_alias).await?; let proto_a = to_substrait_plan(&df_a.to_logical_plan()?)?; - let plan_with_alias = from_substrait_plan(&mut ctx, &proto_a) - .await? - .to_logical_plan()?; + let plan_with_alias = from_substrait_plan(&mut ctx, &proto_a).await?; let df = ctx.sql(sql_no_alias).await?; let proto = to_substrait_plan(&df.to_logical_plan()?)?; - let plan = from_substrait_plan(&mut ctx, &proto) - .await? - .to_logical_plan()?; + let plan = from_substrait_plan(&mut ctx, &proto).await?; println!("{:#?}", plan_with_alias); println!("{:#?}", plan); @@ -218,9 +212,8 @@ mod tests { let df = ctx.sql(sql).await?; let plan = df.to_logical_plan()?; let proto = to_substrait_plan(&plan)?; - - let df = from_substrait_plan(&mut ctx, &proto).await?; - let plan2 = df.to_logical_plan()?; + let plan2 = from_substrait_plan(&mut ctx, &proto).await?; + let plan2 = ctx.optimize(&plan2)?; println!("{:#?}", plan); println!("{:#?}", plan2); diff --git a/tests/serialize.rs b/tests/serialize.rs index 618bb9c..4903a36 100644 --- a/tests/serialize.rs +++ b/tests/serialize.rs @@ -23,8 +23,9 @@ mod tests { // Read substrait plan from file let proto = serializer::deserialize(path).await?; // Check plan equality - let df = from_substrait_plan(&mut ctx, &proto).await?; - let plan = df.to_logical_plan()?; + let plan = from_substrait_plan(&mut ctx, &proto).await?; + let plan = ctx.optimize(&plan)?; + let plan_str_ref = format!("{:?}", plan_ref); let plan_str = format!("{:?}", plan); assert_eq!(plan_str_ref, plan_str);