Skip to content
This repository has been archived by the owner on Jan 13, 2023. It is now read-only.

Change API to return LogicalPlan instead of DataFrame #35

Merged
merged 3 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 30 additions & 24 deletions src/consumer.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use async_recursion::async_recursion;
use datafusion::common::{DFField, DFSchema, DFSchemaRef};
use datafusion::logical_expr::{aggregate_function, LogicalPlan};
use datafusion::logical_expr::{aggregate_function, LogicalPlan, LogicalPlanBuilder};
use datafusion::logical_plan::build_join_schema;
use datafusion::prelude::JoinType;
use datafusion::{
error::{DataFusionError, Result},
logical_plan::{Expr, Operator},
optimizer::utils::split_conjunction,
prelude::{Column, DataFrame, SessionContext},
prelude::{Column, SessionContext},
scalar::ScalarValue,
};

Expand Down Expand Up @@ -67,7 +67,7 @@ pub fn name_to_op(name: &str) -> Result<Operator> {
}

/// Convert Substrait Plan to DataFusion LogicalPlan
pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Result<Arc<DataFrame>> {
pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Result<LogicalPlan> {
// Register function extension
let function_extension = plan
.extensions
Expand Down Expand Up @@ -113,17 +113,18 @@ pub async fn from_substrait_rel(
ctx: &mut SessionContext,
rel: &Rel,
extensions: &HashMap<u32, &String>,
) -> Result<Arc<DataFrame>> {
) -> Result<LogicalPlan> {
match &rel.rel_type {
Some(RelType::Project(p)) => {
if let Some(input) = p.input.as_ref() {
let input = from_substrait_rel(ctx, input, extensions).await?;
let input =
LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?);
let mut exprs: Vec<Expr> = vec![];
for e in &p.expressions {
let x = from_substrait_rex(e, &input.schema(), extensions).await?;
exprs.push(x.as_ref().clone());
}
input.select(exprs)
input.project(exprs)?.build()
} else {
Err(DataFusionError::NotImplemented(
"Projection without an input is not supported".to_string(),
Expand All @@ -132,10 +133,11 @@ pub async fn from_substrait_rel(
}
Some(RelType::Filter(filter)) => {
if let Some(input) = filter.input.as_ref() {
let input = from_substrait_rel(ctx, input, extensions).await?;
let input =
LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?);
if let Some(condition) = filter.condition.as_ref() {
let expr = from_substrait_rex(condition, &input.schema(), extensions).await?;
input.filter(expr.as_ref().clone())
input.filter(expr.as_ref().clone())?.build()
} else {
Err(DataFusionError::NotImplemented(
"Filter without an condition is not valid".to_string(),
Expand All @@ -149,10 +151,11 @@ pub async fn from_substrait_rel(
}
Some(RelType::Fetch(fetch)) => {
if let Some(input) = fetch.input.as_ref() {
let input = from_substrait_rel(ctx, input, extensions).await?;
let input =
LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?);
let offset = fetch.offset as usize;
let count = fetch.count as usize;
input.limit(offset, Some(count))
input.limit(offset, Some(count))?.build()
} else {
Err(DataFusionError::NotImplemented(
"Fetch without an input is not valid".to_string(),
Expand All @@ -161,7 +164,8 @@ pub async fn from_substrait_rel(
}
Some(RelType::Sort(sort)) => {
if let Some(input) = sort.input.as_ref() {
let input = from_substrait_rel(ctx, input, extensions).await?;
let input =
LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?);
let mut sorts: Vec<Expr> = vec![];
for s in &sort.sorts {
let expr =
Expand Down Expand Up @@ -205,7 +209,7 @@ pub async fn from_substrait_rel(
nulls_first: nulls_first,
});
}
input.sort(sorts)
input.sort(sorts)?.build()
} else {
Err(DataFusionError::NotImplemented(
"Sort without an input is not valid".to_string(),
Expand All @@ -214,7 +218,8 @@ pub async fn from_substrait_rel(
}
Some(RelType::Aggregate(agg)) => {
if let Some(input) = agg.input.as_ref() {
let input = from_substrait_rel(ctx, input, extensions).await?;
let input =
LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?);
let mut group_expr = vec![];
let mut aggr_expr = vec![];

Expand Down Expand Up @@ -263,15 +268,17 @@ pub async fn from_substrait_rel(
aggr_expr.push(agg_func?.as_ref().clone());
}

input.aggregate(group_expr, aggr_expr)
input.aggregate(group_expr, aggr_expr)?.build()
} else {
Err(DataFusionError::NotImplemented(
"Aggregate without an input is not valid".to_string(),
))
}
}
Some(RelType::Join(join)) => {
let left = from_substrait_rel(ctx, &join.left.as_ref().unwrap(), extensions).await?;
let left = LogicalPlanBuilder::from(
from_substrait_rel(ctx, &join.left.as_ref().unwrap(), extensions).await?,
);
let right = from_substrait_rel(ctx, &join.right.as_ref().unwrap(), extensions).await?;
let join_type = match join.r#type {
1 => JoinType::Inner,
Expand All @@ -287,15 +294,15 @@ pub async fn from_substrait_rel(
let on =
from_substrait_rex(&join.expression.as_ref().unwrap(), &schema, extensions).await?;
split_conjunction(&on, &mut predicates);
let pairs = predicates
let pairs: Vec<(Column, Column)> = predicates
.iter()
.map(|p| match p {
Expr::BinaryExpr {
left,
op: Operator::Eq,
right,
} => match (left.as_ref(), right.as_ref()) {
(Expr::Column(l), Expr::Column(r)) => Ok((l.flat_name(), r.flat_name())),
(Expr::Column(l), Expr::Column(r)) => Ok((l.clone(), r.clone())),
_ => {
return Err(DataFusionError::Internal(
"invalid join condition".to_string(),
Expand All @@ -309,9 +316,9 @@ pub async fn from_substrait_rel(
}
})
.collect::<Result<Vec<_>>>()?;
let left_cols: Vec<&str> = pairs.iter().map(|(l, _)| l.as_str()).collect();
let right_cols: Vec<&str> = pairs.iter().map(|(_, r)| r.as_str()).collect();
left.join(right, join_type, &left_cols, &right_cols, None)
let (left_cols, right_cols): (Vec<_>, Vec<_>) = pairs.iter().cloned().unzip();
left.join(&right, join_type, (left_cols, right_cols), None)?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the last parameter (the filter, currently set to None) semantically the same field as Substrait JoinRel's post_join_filter?

If not, please ignore the comment.

If so, should we check whether the JoinRel has a post_join_filter and, if it does, parse it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think those are probably the same concept. It would be good to implement this in a separate PR.

.build()
}
Some(RelType::Read(read)) => match &read.as_ref().read_type {
Some(ReadType::NamedTable(nt)) => {
Expand All @@ -334,7 +341,7 @@ pub async fn from_substrait_rel(
table: &nt.names[2],
},
};
let t = ctx.table(table_reference)?;
let t = ctx.table(table_reference)?.to_logical_plan()?;
match &read.projection {
Some(MaskExpression { select, .. }) => match &select.as_ref() {
Some(projection) => {
Expand All @@ -343,7 +350,7 @@ pub async fn from_substrait_rel(
.iter()
.map(|item| item.field as usize)
.collect();
match t.to_logical_plan()? {
match &t {
LogicalPlan::TableScan(scan) => {
let mut scan = scan.clone();
let fields: Vec<DFField> = column_indices
Expand All @@ -354,8 +361,7 @@ pub async fn from_substrait_rel(
scan.projected_schema = DFSchemaRef::new(
DFSchema::new_with_metadata(fields, HashMap::new())?,
);
let plan = LogicalPlan::TableScan(scan);
Ok(Arc::new(DataFrame::new(ctx.state.clone(), &plan)))
Ok(LogicalPlan::TableScan(scan))
}
_ => Err(DataFusionError::Internal(
"unexpected plan for table".to_string(),
Expand Down
21 changes: 7 additions & 14 deletions tests/roundtrip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ mod tests {
let df = ctx.sql(sql).await?;
let plan = df.to_logical_plan()?;
let proto = to_substrait_plan(&plan)?;
let df = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = df.to_logical_plan()?;
let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
let plan2str = format!("{:?}", plan2);
assert_eq!(expected_plan_str, &plan2str);
Ok(())
Expand All @@ -174,9 +173,8 @@ mod tests {
let df = ctx.sql(sql).await?;
let plan1 = df.to_logical_plan()?;
let proto = to_substrait_plan(&plan1)?;

let df = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = df.to_logical_plan()?;
let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = ctx.optimize(&plan2)?;

// Format plan string and replace all None's with 0
let plan1str = format!("{:?}", plan1).replace("None", "0");
Expand All @@ -194,15 +192,11 @@ mod tests {

let df_a = ctx.sql(sql_with_alias).await?;
let proto_a = to_substrait_plan(&df_a.to_logical_plan()?)?;
let plan_with_alias = from_substrait_plan(&mut ctx, &proto_a)
.await?
.to_logical_plan()?;
let plan_with_alias = from_substrait_plan(&mut ctx, &proto_a).await?;

let df = ctx.sql(sql_no_alias).await?;
let proto = to_substrait_plan(&df.to_logical_plan()?)?;
let plan = from_substrait_plan(&mut ctx, &proto)
.await?
.to_logical_plan()?;
let plan = from_substrait_plan(&mut ctx, &proto).await?;

println!("{:#?}", plan_with_alias);
println!("{:#?}", plan);
Expand All @@ -218,9 +212,8 @@ mod tests {
let df = ctx.sql(sql).await?;
let plan = df.to_logical_plan()?;
let proto = to_substrait_plan(&plan)?;

let df = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = df.to_logical_plan()?;
let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = ctx.optimize(&plan2)?;

println!("{:#?}", plan);
println!("{:#?}", plan2);
Expand Down
5 changes: 3 additions & 2 deletions tests/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ mod tests {
// Read substrait plan from file
let proto = serializer::deserialize(path).await?;
// Check plan equality
let df = from_substrait_plan(&mut ctx, &proto).await?;
let plan = df.to_logical_plan()?;
let plan = from_substrait_plan(&mut ctx, &proto).await?;
let plan = ctx.optimize(&plan)?;

let plan_str_ref = format!("{:?}", plan_ref);
let plan_str = format!("{:?}", plan);
assert_eq!(plan_str_ref, plan_str);
Expand Down