Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(substrait): Do not add implicit groupBy expressions when building logical plans from Substrait #14860

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ use crate::execution::context::{SessionState, TaskContext};
use crate::execution::FunctionRegistry;
use crate::logical_expr::utils::find_window_exprs;
use crate::logical_expr::{
col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, TableType,
col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions,
Partitioning, TableType,
};
use crate::physical_plan::{
collect, collect_partitioned, execute_stream, execute_stream_partitioned,
Expand Down Expand Up @@ -526,7 +527,10 @@ impl DataFrame {
) -> Result<DataFrame> {
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
let aggr_expr_len = aggr_expr.len();
let options =
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
let plan = LogicalPlanBuilder::from(self.plan)
.with_options(options)
.aggregate(group_expr, aggr_expr)?
.build()?;
let plan = if is_grouping_set {
Expand Down
121 changes: 112 additions & 9 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{
exec_err, get_target_functional_dependencies, internal_err, not_impl_err,
plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError,
Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef,
DataFusionError, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
};
use datafusion_expr_common::type_coercion::binary::type_union_resolution;

Expand All @@ -63,6 +63,26 @@ use indexmap::IndexSet;
/// Default table name for unnamed table
pub const UNNAMED_TABLE: &str = "?table?";

/// Options for [`LogicalPlanBuilder`]
///
/// Every option starts at its [`Default`] value; use the `with_*` methods to
/// override individual settings.
#[derive(Default, Debug, Clone)]
pub struct LogicalPlanBuilderOptions {
    /// When `true`, the plan builder adds functionally dependent expressions
    /// as additional aggregation groupings. Defaults to `false`.
    add_implicit_group_by_exprs: bool,
}

impl LogicalPlanBuilderOptions {
    /// Creates a set of options with every setting at its default value.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets whether the builder should add functionally dependent expressions
    /// as additional aggregation groupings, returning the updated options.
    pub fn with_add_implicit_group_by_exprs(self, add: bool) -> Self {
        // The struct is spelled out field by field on purpose: adding a new
        // option later becomes a compile error here instead of a silent reset.
        Self {
            add_implicit_group_by_exprs: add,
        }
    }
}

/// Builder for logical plans
///
/// # Example building a simple plan
Expand Down Expand Up @@ -103,19 +123,29 @@ pub const UNNAMED_TABLE: &str = "?table?";
#[derive(Debug, Clone)]
pub struct LogicalPlanBuilder {
    /// The logical plan assembled so far, shared behind an `Arc`.
    plan: Arc<LogicalPlan>,
    /// Options consulted while building plan nodes (for example, whether
    /// implicit group-by expressions are added when building an aggregate).
    options: LogicalPlanBuilderOptions,
}

impl LogicalPlanBuilder {
/// Create a builder from an existing plan
pub fn new(plan: LogicalPlan) -> Self {
Self {
plan: Arc::new(plan),
options: LogicalPlanBuilderOptions::default(),
}
}

/// Create a builder from an existing plan that is already wrapped in an `Arc`
pub fn new_from_arc(plan: Arc<LogicalPlan>) -> Self {
Self { plan }
Self {
plan,
options: LogicalPlanBuilderOptions::default(),
}
}

pub fn with_options(mut self, options: LogicalPlanBuilderOptions) -> Self {
self.options = options;
self
}

/// Return the output schema of the plan built so far
Expand Down Expand Up @@ -1138,8 +1168,12 @@ impl LogicalPlanBuilder {
let group_expr = normalize_cols(group_expr, &self.plan)?;
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;

let group_expr =
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?;
let group_expr = if self.options.add_implicit_group_by_exprs {
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?
} else {
group_expr
};

Aggregate::try_new(self.plan, group_expr, aggr_expr)
.map(LogicalPlan::Aggregate)
.map(Self::new)
Expand Down Expand Up @@ -1550,6 +1584,7 @@ pub fn add_group_by_exprs_from_dependencies(
}
Ok(group_expr)
}

/// Errors if one or more expressions have equal names.
pub fn validate_unique_names<'a>(
node_name: &str,
Expand Down Expand Up @@ -1685,7 +1720,21 @@ pub fn table_scan_with_filter_and_fetch(

pub fn table_source(table_schema: &Schema) -> Arc<dyn TableSource> {
let table_schema = Arc::new(table_schema.clone());
Arc::new(LogicalTableSource { table_schema })
Arc::new(LogicalTableSource {
table_schema,
constraints: Default::default(),
})
}

pub fn table_source_with_constraints(
table_schema: &Schema,
constraints: Constraints,
) -> Arc<dyn TableSource> {
let table_schema = Arc::new(table_schema.clone());
Arc::new(LogicalTableSource {
table_schema,
constraints,
})
}

/// Wrap projection for a plan, if the join keys contain a normal expression.
Expand Down Expand Up @@ -1756,12 +1805,21 @@ pub fn wrap_projection_for_join_if_necessary(
/// DefaultTableSource.
pub struct LogicalTableSource {
    /// Schema returned for this table.
    table_schema: SchemaRef,
    /// Constraints exposed through `TableSource::constraints`;
    /// `Constraints::default()` unless set via `with_constraints`.
    constraints: Constraints,
}

impl LogicalTableSource {
/// Create a new LogicalTableSource
pub fn new(table_schema: SchemaRef) -> Self {
Self { table_schema }
Self {
table_schema,
constraints: Constraints::default(),
}
}

pub fn with_constraints(mut self, constraints: Constraints) -> Self {
self.constraints = constraints;
self
}
}

Expand All @@ -1774,6 +1832,10 @@ impl TableSource for LogicalTableSource {
Arc::clone(&self.table_schema)
}

/// Always yields `Some`, borrowing the constraints stored on this source
/// (which are `Constraints::default()` when none were supplied).
fn constraints(&self) -> Option<&Constraints> {
    let stored = &self.constraints;
    Some(stored)
}

fn supports_filters_pushdown(
&self,
filters: &[&Expr],
Expand Down Expand Up @@ -2023,12 +2085,12 @@ pub fn unnest_with_options(

#[cfg(test)]
mod tests {

use super::*;
use crate::logical_plan::StringifiedPlan;
use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery};

use datafusion_common::{RecursionUnnestOption, SchemaError};
use crate::test::function_stub::sum;
use datafusion_common::{Constraint, RecursionUnnestOption, SchemaError};

#[test]
fn plan_builder_simple() -> Result<()> {
Expand Down Expand Up @@ -2575,4 +2637,45 @@ mod tests {

Ok(())
}

/// With default builder options, a primary key on the scanned table must not
/// cause extra group-by expressions to be added to the aggregate.
#[test]
fn plan_builder_aggregate_without_implicit_group_by_exprs() -> Result<()> {
    // Declare column 0 as the primary key of the table.
    let primary_key = Constraint::PrimaryKey(vec![0]);
    let source = table_source_with_constraints(
        &employee_schema(),
        Constraints::new_unverified(vec![primary_key]),
    );

    // No `with_options` call: the builder keeps its default options.
    let scan = LogicalPlanBuilder::scan("employee_csv", source, Some(vec![0, 3, 4]))?;
    let plan = scan
        .aggregate(vec![col("id")], vec![sum(col("salary"))])?
        .build()?;

    let expected =
        "Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]]\
        \n TableScan: employee_csv projection=[id, state, salary]";
    assert_eq!(expected, plan.to_string());

    Ok(())
}

/// With `add_implicit_group_by_exprs` enabled, grouping by the primary key
/// also pulls the other projected columns into the group-by list.
#[test]
fn plan_builder_aggregate_with_implicit_group_by_exprs() -> Result<()> {
    // Declare column 0 as the primary key of the table.
    let primary_key = Constraint::PrimaryKey(vec![0]);
    let source = table_source_with_constraints(
        &employee_schema(),
        Constraints::new_unverified(vec![primary_key]),
    );

    // Opt in to implicit group-by expressions before calling `aggregate`.
    let implicit_group_by =
        LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
    let scan = LogicalPlanBuilder::scan("employee_csv", source, Some(vec![0, 3, 4]))?;
    let plan = scan
        .with_options(implicit_group_by)
        .aggregate(vec![col("id")], vec![sum(col("salary"))])?
        .build()?;

    let expected =
        "Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]]\
        \n TableScan: employee_csv projection=[id, state, salary]";
    assert_eq!(expected, plan.to_string());

    Ok(())
}
}
2 changes: 1 addition & 1 deletion datafusion/expr/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub mod tree_node;

pub use builder::{
build_join_schema, table_scan, union, wrap_projection_for_join_if_necessary,
LogicalPlanBuilder, LogicalTableSource, UNNAMED_TABLE,
LogicalPlanBuilder, LogicalPlanBuilderOptions, LogicalTableSource, UNNAMED_TABLE,
};
pub use ddl::{
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
Expand Down
9 changes: 8 additions & 1 deletion datafusion/sql/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ use datafusion_expr::utils::{
};
use datafusion_expr::{
qualified_wildcard_with_options, wildcard_with_options, Aggregate, Expr, Filter,
GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning,
GroupingSet, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions,
Partitioning,
};

use indexmap::IndexMap;
Expand Down Expand Up @@ -371,7 +372,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
let agg_expr = agg.aggr_expr.clone();
let (new_input, new_group_by_exprs) =
self.try_process_group_by_unnest(agg)?;
let options = LogicalPlanBuilderOptions::new()
.with_add_implicit_group_by_exprs(true);
LogicalPlanBuilder::from(new_input)
.with_options(options)
.aggregate(new_group_by_exprs, agg_expr)?
.build()
}
Expand Down Expand Up @@ -744,7 +748,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
aggr_exprs: &[Expr],
) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>)> {
// create the aggregate plan
let options =
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
let plan = LogicalPlanBuilder::from(input.clone())
.with_options(options)
.aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())?
.build()?;
let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan {
Expand Down
18 changes: 18 additions & 0 deletions datafusion/substrait/tests/cases/logical_plans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,22 @@ mod tests {

Ok(())
}

/// Reads a Substrait plan containing two stacked aggregates and checks that
/// the resulting logical plan groups only by `sales.product` at both levels.
#[tokio::test]
async fn multilayer_aggregate() -> Result<()> {
    let proto_plan =
        read_json("tests/testdata/test_plans/multilayer_aggregate.substrait.json");
    let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
    let state = ctx.state();
    let plan = from_substrait_plan(&state, &proto_plan).await?;

    let expected =
        "Projection: lower(sales.product) AS lower(product), sum(count(sales.product)) AS product_count\
        \n Aggregate: groupBy=[[sales.product]], aggr=[[sum(count(sales.product))]]\
        \n Aggregate: groupBy=[[sales.product]], aggr=[[count(sales.product)]]\
        \n TableScan: sales";
    assert_eq!(format!("{plan}"), expected);

    Ok(())
}
}
11 changes: 11 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,17 @@ async fn aggregate_grouping_rollup() -> Result<()> {
).await
}

/// Checks the plan produced for a query that aggregates over the result of
/// another aggregate: both levels must group by `data.a` only.
#[tokio::test]
async fn multilayer_aggregate() -> Result<()> {
    let sql = "SELECT a, sum(partial_count_b) FROM (SELECT a, count(b) as partial_count_b FROM data GROUP BY a) GROUP BY a";
    let expected_plan =
        "Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS sum(partial_count_b)]]\
        \n Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]]\
        \n TableScan: data projection=[a, b]";

    // The trailing `true` is passed through unchanged; its meaning is defined
    // by `assert_expected_plan`.
    assert_expected_plan(sql, expected_plan, true).await
}

#[tokio::test]
async fn decimal_literal() -> Result<()> {
roundtrip("SELECT * FROM data WHERE b > 2.5").await
Expand Down
Loading