Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(substrait): remove optimize calls from substrait consumer #12800

Merged
merged 7 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/substrait/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
//!
//! // Receive a substrait protobuf from somewhere, and turn it into a LogicalPlan
//! let logical_round_trip = logical_plan::consumer::from_substrait_plan(&ctx, &substrait_plan).await?;
//! let logical_round_trip = ctx.state().optimize(&logical_round_trip)?;
//! assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip));
//! # Ok(())
//! # }
Expand Down
163 changes: 97 additions & 66 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ use crate::variation_const::{
use datafusion::arrow::array::{new_empty_array, AsArray};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::dataframe::DataFrame;
use datafusion::logical_expr::builder::project;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{
col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
Expand All @@ -69,7 +68,7 @@ use datafusion::{
prelude::{Column, SessionContext},
scalar::ScalarValue,
};
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;
use std::sync::Arc;
use substrait::proto::exchange_rel::ExchangeKind;
use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode;
Expand Down Expand Up @@ -227,7 +226,6 @@ pub async fn from_substrait_plan(
// Nothing to do if the schema is already equivalent
return Ok(plan);
}

match plan {
// If the last node of the plan produces expressions, bake the renames into those expressions.
// This isn't necessary for correctness, but helps with roundtrip tests.
Expand Down Expand Up @@ -327,54 +325,35 @@ pub async fn from_substrait_extended_expr(
})
}

/// parse projection
pub fn extract_projection(
t: LogicalPlan,
projection: &::core::option::Option<expression::MaskExpression>,
) -> Result<LogicalPlan> {
match projection {
pub fn apply_masking(
schema: DFSchema,
mask_expression: &::core::option::Option<expression::MaskExpression>,
) -> Result<DFSchema> {
match mask_expression {
Some(MaskExpression { select, .. }) => match &select.as_ref() {
Some(projection) => {
let column_indices: Vec<usize> = projection
.struct_items
.iter()
.map(|item| item.field as usize)
.collect();
match t {
LogicalPlan::TableScan(mut scan) => {
let fields = column_indices
.iter()
.map(|i| scan.projected_schema.qualified_field(*i))
.map(|(qualifier, field)| {
(qualifier.cloned(), Arc::new(field.clone()))
})
.collect();
scan.projection = Some(column_indices);
scan.projected_schema = DFSchemaRef::new(
DFSchema::new_with_metadata(fields, HashMap::new())?,
);
Ok(LogicalPlan::TableScan(scan))
}
LogicalPlan::Projection(projection) => {
// create another Projection around the Projection to handle the field masking
let fields: Vec<Expr> = column_indices
.into_iter()
.map(|i| {
let (qualifier, field) =
projection.schema.qualified_field(i);
let column =
Column::new(qualifier.cloned(), field.name());
Expr::Column(column)
})
.collect();
project(LogicalPlan::Projection(projection), fields)
}
_ => plan_err!("unexpected plan for table"),
}

let fields = column_indices
.iter()
.map(|i| schema.qualified_field(*i))
.map(|(qualifier, field)| {
(qualifier.cloned(), Arc::new(field.clone()))
})
.collect();

Ok(DFSchema::new_with_metadata(
fields,
schema.metadata().clone(),
)?)
}
_ => Ok(t),
None => Ok(schema),
},
_ => Ok(t),
None => Ok(schema),
}
}

Expand Down Expand Up @@ -777,14 +756,20 @@ pub async fn from_substrait_rel(
},
};

let t = ctx.table(table_reference.clone()).await?;

let substrait_schema =
from_substrait_named_struct(named_struct, extensions)?
.replace_qualifier(table_reference.clone());
.replace_qualifier(table_reference);

let t = ctx.table(table_reference.clone()).await?;
let t = ensure_schema_compatability(t, substrait_schema)?;
let t = t.into_optimized_plan()?;
extract_projection(t, &read.projection)
ensure_schema_compatability(
t.schema().to_owned(),
substrait_schema.clone(),
)?;

let substrait_schema = apply_masking(substrait_schema, &read.projection)?;

apply_projection(t, substrait_schema)
}
Some(ReadType::VirtualTable(vt)) => {
let base_schema = read.base_schema.as_ref().ok_or_else(|| {
Expand Down Expand Up @@ -835,6 +820,10 @@ pub async fn from_substrait_rel(
}))
}
Some(ReadType::LocalFiles(lf)) => {
let named_struct = read.base_schema.as_ref().ok_or_else(|| {
substrait_datafusion_err!("No base schema provided for LocalFiles")
})?;

fn extract_filename(name: &str) -> Option<String> {
let corrected_url =
if name.starts_with("file://") && !name.starts_with("file:///") {
Expand Down Expand Up @@ -865,9 +854,20 @@ pub async fn from_substrait_rel(
let name = filename.unwrap();
// directly use unwrap here since we could determine it is a valid one
let table_reference = TableReference::Bare { table: name.into() };
let t = ctx.table(table_reference).await?;
let t = t.into_optimized_plan()?;
extract_projection(t, &read.projection)
let t = ctx.table(table_reference.clone()).await?;

let substrait_schema =
from_substrait_named_struct(named_struct, extensions)?
.replace_qualifier(table_reference);

ensure_schema_compatability(
t.schema().to_owned(),
substrait_schema.clone(),
)?;

let substrait_schema = apply_masking(substrait_schema, &read.projection)?;

apply_projection(t, substrait_schema)
}
_ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type),
},
Expand Down Expand Up @@ -984,30 +984,61 @@ pub async fn from_substrait_rel(
/// 1. All fields present in the Substrait schema are present in the DataFusion schema. The
/// DataFusion schema may have MORE fields, but not the other way around.
/// 2. All fields are compatible. See [`ensure_field_compatability`] for details
///
/// This function returns a DataFrame with fields adjusted if necessary in the event that the
/// Substrait schema is a subset of the DataFusion schema.
fn ensure_schema_compatability(
table: DataFrame,
table_schema: DFSchema,
substrait_schema: DFSchema,
) -> Result<DataFrame> {
let df_schema = table.schema().to_owned().strip_qualifiers();
if df_schema.logically_equivalent_names_and_types(&substrait_schema) {
return Ok(table);
}
let selected_columns = substrait_schema
) -> Result<()> {
substrait_schema
.strip_qualifiers()
.fields()
.iter()
.map(|substrait_field| {
.try_for_each(|substrait_field| {
let df_field =
df_schema.field_with_unqualified_name(substrait_field.name())?;
ensure_field_compatability(df_field, substrait_field)?;
Ok(col(format!("\"{}\"", df_field.name())))
table_schema.field_with_unqualified_name(substrait_field.name())?;
ensure_field_compatability(df_field, substrait_field)
})
.collect::<Result<_>>()?;
}

/// Returns the unoptimized plan of `table`, narrowed to the fields of `substrait_schema`
/// when the Substrait schema is a subset of the DataFusion schema.
///
/// If the table schema and `substrait_schema` are reported equivalent by
/// `logically_equivalent_names_and_types`, the plan is returned untouched. Otherwise the
/// plan must be a [`LogicalPlan::TableScan`]: its `projection` is set to the indices of the
/// Substrait fields (looked up by unqualified name in the table schema) and its
/// `projected_schema` is rebuilt from those fields, keeping the table schema's metadata.
///
/// # Errors
///
/// Returns a plan error if the plan is not a `TableScan`, or if a Substrait field name is
/// not found in the table schema, and propagates any error raised while building the
/// projected schema.
fn apply_projection(table: DataFrame, substrait_schema: DFSchema) -> Result<LogicalPlan> {
    // Capture the schema first: `into_unoptimized_plan` consumes the DataFrame.
    let df_schema = table.schema().to_owned();

    let t = table.into_unoptimized_plan();

    if df_schema.logically_equivalent_names_and_types(&substrait_schema) {
        return Ok(t);
    }

    match t {
        LogicalPlan::TableScan(mut scan) => {
            let column_indices: Vec<usize> = substrait_schema
                .strip_qualifiers()
                .fields()
                .iter()
                .map(|substrait_field| {
                    let name = substrait_field.name().as_str();
                    // Report a missing column as an error instead of panicking.
                    match df_schema.index_of_column_by_name(None, name) {
                        Some(index) => Ok(index),
                        None => plan_err!(
                            "Substrait field {} was not found in the table schema",
                            name
                        ),
                    }
                })
                .collect::<Result<_>>()?;

            let fields = column_indices
                .iter()
                .map(|i| df_schema.qualified_field(*i))
                .map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone())))
                .collect();

            scan.projected_schema = DFSchemaRef::new(DFSchema::new_with_metadata(
                fields,
                df_schema.metadata().clone(),
            )?);
            scan.projection = Some(column_indices);

            Ok(LogicalPlan::TableScan(scan))
        }
        _ => plan_err!("DataFrame passed to apply_projection must be a TableScan"),
    }
}

/// Ensures that the given Substrait field is compatible with the given DataFusion field
Expand Down
Loading
Loading