Skip to content

Commit

Permalink
fix(substrait): remove optimize calls from substrait consumer (#12800)
Browse files Browse the repository at this point in the history
* fix(substrait): remove optimize calls from substrait consumer

* fix(substrait): fix schema comparison in ensure_schema_compatability

* fix(substrait): correctly apply read projections

* fix(substrait): nits

* fix(substrait): split schema validation and apply_projection

* fix(substrait): return an error when apply_projection is called with something other than a TableScan

* fix(substrait): clippy errors
  • Loading branch information
tokoko authored Oct 15, 2024
1 parent 4a0b768 commit 5a0ea0b
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 141 deletions.
1 change: 1 addition & 0 deletions datafusion/substrait/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
//!
//! // Receive a substrait protobuf from somewhere, and turn it into a LogicalPlan
//! let logical_round_trip = logical_plan::consumer::from_substrait_plan(&ctx, &substrait_plan).await?;
//! let logical_round_trip = ctx.state().optimize(&logical_round_trip)?;
//! assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip));
//! # Ok(())
//! # }
Expand Down
163 changes: 97 additions & 66 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ use crate::variation_const::{
use datafusion::arrow::array::{new_empty_array, AsArray};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::dataframe::DataFrame;
use datafusion::logical_expr::builder::project;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{
col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
Expand All @@ -69,7 +68,7 @@ use datafusion::{
prelude::{Column, SessionContext},
scalar::ScalarValue,
};
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;
use std::sync::Arc;
use substrait::proto::exchange_rel::ExchangeKind;
use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode;
Expand Down Expand Up @@ -227,7 +226,6 @@ pub async fn from_substrait_plan(
// Nothing to do if the schema is already equivalent
return Ok(plan);
}

match plan {
// If the last node of the plan produces expressions, bake the renames into those expressions.
// This isn't necessary for correctness, but helps with roundtrip tests.
Expand Down Expand Up @@ -327,54 +325,35 @@ pub async fn from_substrait_extended_expr(
})
}

/// parse projection
pub fn extract_projection(
t: LogicalPlan,
projection: &::core::option::Option<expression::MaskExpression>,
) -> Result<LogicalPlan> {
match projection {
pub fn apply_masking(
schema: DFSchema,
mask_expression: &::core::option::Option<expression::MaskExpression>,
) -> Result<DFSchema> {
match mask_expression {
Some(MaskExpression { select, .. }) => match &select.as_ref() {
Some(projection) => {
let column_indices: Vec<usize> = projection
.struct_items
.iter()
.map(|item| item.field as usize)
.collect();
match t {
LogicalPlan::TableScan(mut scan) => {
let fields = column_indices
.iter()
.map(|i| scan.projected_schema.qualified_field(*i))
.map(|(qualifier, field)| {
(qualifier.cloned(), Arc::new(field.clone()))
})
.collect();
scan.projection = Some(column_indices);
scan.projected_schema = DFSchemaRef::new(
DFSchema::new_with_metadata(fields, HashMap::new())?,
);
Ok(LogicalPlan::TableScan(scan))
}
LogicalPlan::Projection(projection) => {
// create another Projection around the Projection to handle the field masking
let fields: Vec<Expr> = column_indices
.into_iter()
.map(|i| {
let (qualifier, field) =
projection.schema.qualified_field(i);
let column =
Column::new(qualifier.cloned(), field.name());
Expr::Column(column)
})
.collect();
project(LogicalPlan::Projection(projection), fields)
}
_ => plan_err!("unexpected plan for table"),
}

let fields = column_indices
.iter()
.map(|i| schema.qualified_field(*i))
.map(|(qualifier, field)| {
(qualifier.cloned(), Arc::new(field.clone()))
})
.collect();

Ok(DFSchema::new_with_metadata(
fields,
schema.metadata().clone(),
)?)
}
_ => Ok(t),
None => Ok(schema),
},
_ => Ok(t),
None => Ok(schema),
}
}

Expand Down Expand Up @@ -777,14 +756,20 @@ pub async fn from_substrait_rel(
},
};

let t = ctx.table(table_reference.clone()).await?;

let substrait_schema =
from_substrait_named_struct(named_struct, extensions)?
.replace_qualifier(table_reference.clone());
.replace_qualifier(table_reference);

let t = ctx.table(table_reference.clone()).await?;
let t = ensure_schema_compatability(t, substrait_schema)?;
let t = t.into_optimized_plan()?;
extract_projection(t, &read.projection)
ensure_schema_compatability(
t.schema().to_owned(),
substrait_schema.clone(),
)?;

let substrait_schema = apply_masking(substrait_schema, &read.projection)?;

apply_projection(t, substrait_schema)
}
Some(ReadType::VirtualTable(vt)) => {
let base_schema = read.base_schema.as_ref().ok_or_else(|| {
Expand Down Expand Up @@ -835,6 +820,10 @@ pub async fn from_substrait_rel(
}))
}
Some(ReadType::LocalFiles(lf)) => {
let named_struct = read.base_schema.as_ref().ok_or_else(|| {
substrait_datafusion_err!("No base schema provided for LocalFiles")
})?;

fn extract_filename(name: &str) -> Option<String> {
let corrected_url =
if name.starts_with("file://") && !name.starts_with("file:///") {
Expand Down Expand Up @@ -865,9 +854,20 @@ pub async fn from_substrait_rel(
let name = filename.unwrap();
// directly use unwrap here since we could determine it is a valid one
let table_reference = TableReference::Bare { table: name.into() };
let t = ctx.table(table_reference).await?;
let t = t.into_optimized_plan()?;
extract_projection(t, &read.projection)
let t = ctx.table(table_reference.clone()).await?;

let substrait_schema =
from_substrait_named_struct(named_struct, extensions)?
.replace_qualifier(table_reference);

ensure_schema_compatability(
t.schema().to_owned(),
substrait_schema.clone(),
)?;

let substrait_schema = apply_masking(substrait_schema, &read.projection)?;

apply_projection(t, substrait_schema)
}
_ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type),
},
Expand Down Expand Up @@ -995,30 +995,61 @@ pub async fn from_substrait_rel(
/// 1. All fields present in the Substrait schema are present in the DataFusion schema. The
/// DataFusion schema may have MORE fields, but not the other way around.
/// 2. All fields are compatible. See [`ensure_field_compatability`] for details
///
/// This function returns a DataFrame with fields adjusted if necessary in the event that the
/// Substrait schema is a subset of the DataFusion schema.
fn ensure_schema_compatability(
table: DataFrame,
table_schema: DFSchema,
substrait_schema: DFSchema,
) -> Result<DataFrame> {
let df_schema = table.schema().to_owned().strip_qualifiers();
if df_schema.logically_equivalent_names_and_types(&substrait_schema) {
return Ok(table);
}
let selected_columns = substrait_schema
) -> Result<()> {
substrait_schema
.strip_qualifiers()
.fields()
.iter()
.map(|substrait_field| {
.try_for_each(|substrait_field| {
let df_field =
df_schema.field_with_unqualified_name(substrait_field.name())?;
ensure_field_compatability(df_field, substrait_field)?;
Ok(col(format!("\"{}\"", df_field.name())))
table_schema.field_with_unqualified_name(substrait_field.name())?;
ensure_field_compatability(df_field, substrait_field)
})
.collect::<Result<_>>()?;
}

/// This function returns a DataFrame with fields adjusted if necessary in the event that the
/// Substrait schema is a subset of the DataFusion schema.
fn apply_projection(table: DataFrame, substrait_schema: DFSchema) -> Result<LogicalPlan> {
let df_schema = table.schema().to_owned();

table.select(selected_columns)
let t = table.into_unoptimized_plan();

if df_schema.logically_equivalent_names_and_types(&substrait_schema) {
return Ok(t);
}

match t {
LogicalPlan::TableScan(mut scan) => {
let column_indices: Vec<usize> = substrait_schema
.strip_qualifiers()
.fields()
.iter()
.map(|substrait_field| {
Ok(df_schema
.index_of_column_by_name(None, substrait_field.name().as_str())
.unwrap())
})
.collect::<Result<_>>()?;

let fields = column_indices
.iter()
.map(|i| df_schema.qualified_field(*i))
.map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone())))
.collect();

scan.projected_schema = DFSchemaRef::new(DFSchema::new_with_metadata(
fields,
df_schema.metadata().clone(),
)?);
scan.projection = Some(column_indices);

Ok(LogicalPlan::TableScan(scan))
}
_ => plan_err!("DataFrame passed to apply_projection must be a TableScan"),
}
}

/// Ensures that the given Substrait field is compatible with the given DataFusion field
Expand Down
Loading

0 comments on commit 5a0ea0b

Please sign in to comment.