Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[iteration #1] fix(substrait): Do not add implicit groupBy expressions when building logical plans from Substrait #14553

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1089,12 +1089,33 @@ impl LogicalPlanBuilder {
self,
group_expr: impl IntoIterator<Item = impl Into<Expr>>,
aggr_expr: impl IntoIterator<Item = impl Into<Expr>>,
) -> Result<Self> {
self.aggregate_inner(group_expr, aggr_expr, true)
}

pub fn aggregate_without_implicit_group_by_exprs(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I agree adding LogicalPlanBuilder::aggregate_without_implicit_group_by_exprs is not good (especially without documentation explaining the difference)

What I suggest we do (perhaps as a different PR) is to add a flag to the builder to control this behavior

struct LogicalPlanBuilder { 
...
  /// Should the plan builder add implicit group bys to the plan based on constraints
  add_implicit_group_by_exprs: bool,
}

Then when that behavior is needed (in the sql planner) it could be enabled like

        input
            .with_add_implicit_group_by_exprs(true) // new method to see the flag
            .aggregate(group_exprs, aggr_exprs)?
            .build()

Is this something you would be willing to try @anlinc or @Blizzara ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm taking a look now!

It does indeed make sense to have this disabled by default, and enabled only on the SQL path.

I also want to experiment with @Blizzara's suggestion -- we could inline the additional expressions change on the SQL plan path instead. Part of why we may not want a variable is:

  • It really only applies to one construct in the builder (aggregations).
  • It's probably not a popular configuration to use.

self,
group_expr: impl IntoIterator<Item = impl Into<Expr>>,
aggr_expr: impl IntoIterator<Item = impl Into<Expr>>,
) -> Result<Self> {
self.aggregate_inner(group_expr, aggr_expr, false)
}

fn aggregate_inner(
self,
group_expr: impl IntoIterator<Item = impl Into<Expr>>,
aggr_expr: impl IntoIterator<Item = impl Into<Expr>>,
include_implicit_group_by_exprs: bool,
) -> Result<Self> {
let group_expr = normalize_cols(group_expr, &self.plan)?;
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;

let group_expr =
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?;
let group_expr = if include_implicit_group_by_exprs {
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?
} else {
group_expr
};

Aggregate::try_new(self.plan, group_expr, aggr_expr)
.map(LogicalPlan::Aggregate)
.map(Self::new)
Expand Down Expand Up @@ -1235,7 +1256,7 @@ impl LogicalPlanBuilder {
.map(|(l, r)| {
let left_key = l.into();
let right_key = r.into();
let mut left_using_columns = HashSet::new();
let mut left_using_columns = HashSet::new();
expr_to_columns(&left_key, &mut left_using_columns)?;
let normalized_left_key = normalize_col_with_schemas_and_ambiguity_check(
left_key,
Expand All @@ -1253,12 +1274,12 @@ impl LogicalPlanBuilder {

// find valid equijoin
find_valid_equijoin_key_pair(
&normalized_left_key,
&normalized_right_key,
self.plan.schema(),
right.schema(),
)?.ok_or_else(||
plan_datafusion_err!(
&normalized_left_key,
&normalized_right_key,
self.plan.schema(),
right.schema(),
)?.ok_or_else(||
plan_datafusion_err!(
"can't create join plan, join key should belong to one input, error key: ({normalized_left_key},{normalized_right_key})"
))
})
Expand Down Expand Up @@ -1495,7 +1516,7 @@ pub fn validate_unique_names<'a>(
None => {
unique_names.insert(name, (position, expr));
Ok(())
},
}
Some((existing_position, existing_expr)) => {
plan_err!("{node_name} require unique expression names \
but the expression \"{existing_expr}\" at position {existing_position} and \"{expr}\" \
Expand Down Expand Up @@ -1962,7 +1983,6 @@ pub fn unnest_with_options(

#[cfg(test)]
mod tests {

use super::*;
use crate::logical_plan::StringifiedPlan;
use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery};
Expand Down
7 changes: 6 additions & 1 deletion datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,12 @@ pub async fn from_aggregate_rel(
};
aggr_exprs.push(agg_func?.as_ref().clone());
}
input.aggregate(group_exprs, aggr_exprs)?.build()

// Do not include implicit group by expressions (from functional dependencies) when building plans from Substrait.
// Otherwise, the ordinal-based emits applied later will point to incorrect expressions.
input
.aggregate_without_implicit_group_by_exprs(group_exprs, aggr_exprs)?
.build()
} else {
not_impl_err!("Aggregate without an input is not valid")
}
Expand Down
18 changes: 18 additions & 0 deletions datafusion/substrait/tests/cases/logical_plans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,22 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn multilayer_aggregate() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/multilayer_aggregate.substrait.json");
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;

assert_eq!(
format!("{}", plan),
"Projection: lower(sales.product) AS lower(product), sum(count(sales.product)) AS product_count\
\n Aggregate: groupBy=[[sales.product]], aggr=[[sum(count(sales.product))]]\
\n Aggregate: groupBy=[[sales.product]], aggr=[[count(sales.product)]]\
\n TableScan: sales"
);

Ok(())
}
}
11 changes: 11 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,17 @@ async fn aggregate_grouping_rollup() -> Result<()> {
).await
}

#[tokio::test]
async fn multilayer_aggregate() -> Result<()> {
assert_expected_plan(
"SELECT a, sum(partial_count_b) FROM (SELECT a, count(b) as partial_count_b FROM data GROUP BY a) GROUP BY a",
"Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS sum(partial_count_b)]]\
\n Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]]\
\n TableScan: data projection=[a, b]",
true
).await
}

#[tokio::test]
async fn decimal_literal() -> Result<()> {
roundtrip("SELECT * FROM data WHERE b > 2.5").await
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
{
"extensionUris": [{
"extensionUriAnchor": 1,
"uri": "/functions_aggregate_generic.yaml"
}, {
"extensionUriAnchor": 2,
"uri": "/functions_arithmetic.yaml"
}, {
"extensionUriAnchor": 3,
"uri": "/functions_string.yaml"
}],
"extensions": [{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 0,
"name": "count:any"
}
}, {
"extensionFunction": {
"extensionUriReference": 2,
"functionAnchor": 1,
"name": "sum:i64"
}
}, {
"extensionFunction": {
"extensionUriReference": 3,
"functionAnchor": 2,
"name": "lower:str"
}
}],
"relations": [{
"root": {
"input": {
"project": {
"common": {
"emit": {
"outputMapping": [2, 3]
}
},
"input": {
"aggregate": {
"common": {
"direct": {
}
},
"input": {
"aggregate": {
"common": {
"direct": {
}
},
"input": {
"read": {
"common": {
"direct": {}
},
"baseSchema": {
"names": [
"product"
],
"struct": {
"types": [
{
"string": {
"nullability": "NULLABILITY_REQUIRED"
}
}
],
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": [
"sales"
]
}
}
},
"groupings": [{
"groupingExpressions": [{
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}],
"expressionReferences": []
}],
"measures": [{
"measure": {
"functionReference": 0,
"args": [],
"sorts": [],
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
"outputType": {
"i64": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
},
"invocation": "AGGREGATION_INVOCATION_ALL",
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}
}],
"options": []
}
}],
"groupingExpressions": []
}
},
"groupings": [{
"groupingExpressions": [{
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}],
"expressionReferences": []
}],
"measures": [{
"measure": {
"functionReference": 1,
"args": [],
"sorts": [],
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
"outputType": {
"i64": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
},
"invocation": "AGGREGATION_INVOCATION_ALL",
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {
}
}
}
}],
"options": []
}
}],
"groupingExpressions": []
}
},
"expressions": [{
"scalarFunction": {
"functionReference": 2,
"args": [],
"outputType": {
"string": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
},
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}
}],
"options": []
}
}, {
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {
}
}
}]
}
},
"names": ["lower(product)", "product_count"]
}
}],
"expectedTypeUrls": []
}
Loading