Skip to content

Commit

Permalink
Add SQL planner support for ROLLUP and CUBE grouping set expressi…
Browse files Browse the repository at this point in the history
…ons (#2446)

* Add SQL planner support for ROLLUP and CUBE grouping sets

* prep for review

* fix more todo comments

* code cleanup

* clippy

* fmt and clippy

* revert change

* clippy
  • Loading branch information
andygrove authored May 9, 2022
1 parent dfdeb42 commit 1fe038f
Show file tree
Hide file tree
Showing 14 changed files with 377 additions and 32 deletions.
1 change: 1 addition & 0 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> {
| Expr::InSubquery { .. }
| Expr::ScalarSubquery(_)
| Expr::GetIndexedField { .. }
| Expr::GroupingSet(_)
| Expr::Case { .. } => Recursion::Continue(self),

Expr::ScalarFunction { fun, .. } => self.visit_volatility(fun.volatility()),
Expand Down
15 changes: 8 additions & 7 deletions datafusion/core/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ use std::{
sync::Arc,
};

use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
use super::{Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
use crate::logical_plan::expr::exprlist_to_fields;
use crate::logical_plan::{
columnize_expr, normalize_col, normalize_cols, provider_as_source,
rewrite_sort_cols_by_aggs, Column, CrossJoin, DFField, DFSchema, DFSchemaRef, Limit,
Expand Down Expand Up @@ -557,7 +558,7 @@ impl LogicalPlanBuilder {
expr.extend(missing_exprs);

let new_schema = DFSchema::new_with_metadata(
exprlist_to_fields(&expr, input_schema)?,
exprlist_to_fields(&expr, &input)?,
input_schema.metadata().clone(),
)?;

Expand Down Expand Up @@ -629,7 +630,7 @@ impl LogicalPlanBuilder {
.map(|f| Expr::Column(f.qualified_column()))
.collect();
let new_schema = DFSchema::new_with_metadata(
exprlist_to_fields(&new_expr, schema)?,
exprlist_to_fields(&new_expr, &self.plan)?,
schema.metadata().clone(),
)?;

Expand Down Expand Up @@ -843,8 +844,7 @@ impl LogicalPlanBuilder {
let window_expr = normalize_cols(window_expr, &self.plan)?;
let all_expr = window_expr.iter();
validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?;
let mut window_fields: Vec<DFField> =
exprlist_to_fields(all_expr, self.plan.schema())?;
let mut window_fields: Vec<DFField> = exprlist_to_fields(all_expr, &self.plan)?;
window_fields.extend_from_slice(self.plan.schema().fields());
Ok(Self::from(LogicalPlan::Window(Window {
input: Arc::new(self.plan.clone()),
Expand All @@ -869,7 +869,7 @@ impl LogicalPlanBuilder {
let all_expr = group_expr.iter().chain(aggr_expr.iter());
validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?;
let aggr_schema = DFSchema::new_with_metadata(
exprlist_to_fields(all_expr, self.plan.schema())?,
exprlist_to_fields(all_expr, &self.plan)?,
self.plan.schema().metadata().clone(),
)?;
Ok(Self::from(LogicalPlan::Aggregate(Aggregate {
Expand Down Expand Up @@ -1126,13 +1126,14 @@ pub fn project_with_alias(
}
validate_unique_names("Projections", projected_expr.iter(), input_schema)?;
let input_schema = DFSchema::new_with_metadata(
exprlist_to_fields(&projected_expr, input_schema)?,
exprlist_to_fields(&projected_expr, &plan)?,
plan.schema().metadata().clone(),
)?;
let schema = match alias {
Some(ref alias) => input_schema.replace_qualifier(alias.as_str()),
None => input_schema,
};

Ok(LogicalPlan::Projection(Projection {
expr: projected_expr,
input: Arc::new(plan.clone()),
Expand Down
31 changes: 28 additions & 3 deletions datafusion/core/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ pub use super::Operator;
use crate::error::Result;
use crate::logical_plan::ExprSchemable;
use crate::logical_plan::{DFField, DFSchema};
use crate::sql::utils::find_columns_referenced_by_expr;
use arrow::datatypes::DataType;
pub use datafusion_common::{Column, ExprSchema};
pub use datafusion_expr::expr_fn::*;
use datafusion_expr::AccumulatorFunctionImplementation;
use datafusion_expr::BuiltinScalarFunction;
pub use datafusion_expr::Expr;
use datafusion_expr::StateTypeFunction;
pub use datafusion_expr::{lit, lit_timestamp_nano, Literal};
use datafusion_expr::{AccumulatorFunctionImplementation, LogicalPlan};
use datafusion_expr::{AggregateUDF, ScalarUDF};
use datafusion_expr::{
ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility,
Expand Down Expand Up @@ -138,9 +139,33 @@ pub fn create_udaf(
/// Create field meta-data from an expression, for use in a result set schema
pub fn exprlist_to_fields<'a>(
expr: impl IntoIterator<Item = &'a Expr>,
input_schema: &DFSchema,
plan: &LogicalPlan,
) -> Result<Vec<DFField>> {
expr.into_iter().map(|e| e.to_field(input_schema)).collect()
match plan {
LogicalPlan::Aggregate(agg) => {
let group_expr: Vec<Column> = agg
.group_expr
.iter()
.flat_map(find_columns_referenced_by_expr)
.collect();
let exprs: Vec<Expr> = expr.into_iter().cloned().collect();
let mut fields = vec![];
for expr in &exprs {
match expr {
Expr::Column(c) if group_expr.iter().any(|x| x == c) => {
// resolve against schema of input to aggregate
fields.push(expr.to_field(agg.input.schema())?);
}
_ => fields.push(expr.to_field(plan.schema())?),
}
}
Ok(fields)
}
_ => {
let input_schema = &plan.schema();
expr.into_iter().map(|e| e.to_field(input_schema)).collect()
}
}
}

/// Calls a named built in function
Expand Down
17 changes: 17 additions & 0 deletions datafusion/core/src/logical_plan/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::logical_plan::ExprSchemable;
use crate::logical_plan::LogicalPlan;
use datafusion_common::Column;
use datafusion_common::Result;
use datafusion_expr::expr::GroupingSet;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
Expand Down Expand Up @@ -215,6 +216,22 @@ impl ExprRewritable for Expr {
fun,
distinct,
},
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
Expr::GroupingSet(GroupingSet::Rollup(rewrite_vec(exprs, rewriter)?))
}
GroupingSet::Cube(exprs) => {
Expr::GroupingSet(GroupingSet::Cube(rewrite_vec(exprs, rewriter)?))
}
GroupingSet::GroupingSets(lists_of_exprs) => {
Expr::GroupingSet(GroupingSet::GroupingSets(
lists_of_exprs
.iter()
.map(|exprs| rewrite_vec(exprs.clone(), rewriter))
.collect::<Result<Vec<_>>>()?,
))
}
},
Expr::AggregateUDF { args, fun } => Expr::AggregateUDF {
args: rewrite_vec(args, rewriter)?,
fun,
Expand Down
14 changes: 14 additions & 0 deletions datafusion/core/src/logical_plan/expr_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
use super::Expr;
use datafusion_common::Result;
use datafusion_expr::expr::GroupingSet;

/// Controls how the visitor recursion should proceed.
pub enum Recursion<V: ExpressionVisitor> {
Expand Down Expand Up @@ -103,6 +104,19 @@ impl ExprVisitable for Expr {
| Expr::TryCast { expr, .. }
| Expr::Sort { expr, .. }
| Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
Expr::GroupingSet(GroupingSet::Rollup(exprs)) => exprs
.iter()
.fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))),
Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs
.iter()
.fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))),
Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
lists_of_exprs.iter().fold(Ok(visitor), |v, exprs| {
v.and_then(|v| {
exprs.iter().fold(Ok(v), |v, e| v.and_then(|v| e.accept(v)))
})
})
}
Expr::Column(_)
| Expr::ScalarVariable(_, _)
| Expr::Literal(_)
Expand Down
28 changes: 28 additions & 0 deletions datafusion/core/src/optimizer/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use crate::logical_plan::{
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use arrow::datatypes::DataType;
use datafusion_expr::expr::GroupingSet;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

Expand Down Expand Up @@ -482,6 +483,33 @@ impl ExprIdentifierVisitor<'_> {
desc.push_str("GetIndexedField-");
desc.push_str(&key.to_string());
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
desc.push_str("Rollup");
for expr in exprs {
desc.push('-');
desc.push_str(&Self::desc_expr(expr));
}
}
GroupingSet::Cube(exprs) => {
desc.push_str("Cube");
for expr in exprs {
desc.push('-');
desc.push_str(&Self::desc_expr(expr));
}
}
GroupingSet::GroupingSets(lists_of_exprs) => {
desc.push_str("GroupingSets");
for exprs in lists_of_exprs {
desc.push('(');
for expr in exprs {
desc.push('-');
desc.push_str(&Self::desc_expr(expr));
}
desc.push(')');
}
}
},
}

desc
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/optimizer/projection_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ mod tests {
// that the Column references are unqualified (e.g. their
// relation is `None`). PlanBuilder resolves the expressions
let expr = vec![col("a"), col("b")];
let projected_fields = exprlist_to_fields(&expr, input_schema).unwrap();
let projected_fields = exprlist_to_fields(&expr, &table_scan).unwrap();
let projected_schema = DFSchema::new_with_metadata(
projected_fields,
input_schema.metadata().clone(),
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/src/optimizer/simplify_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ impl<'a> ConstEvaluator<'a> {
| Expr::ScalarSubquery(_)
| Expr::WindowFunction { .. }
| Expr::Sort { .. }
| Expr::GroupingSet(_)
| Expr::Wildcard
| Expr::QualifiedWildcard { .. } => false,
Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()),
Expand Down
20 changes: 20 additions & 0 deletions datafusion/core/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use crate::{
logical_plan::ExpressionVisitor,
};
use datafusion_common::DFSchema;
use datafusion_expr::expr::GroupingSet;
use std::{collections::HashSet, sync::Arc};

const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__";
Expand Down Expand Up @@ -83,6 +84,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> {
| Expr::ScalarUDF { .. }
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
| Expr::GroupingSet(_)
| Expr::AggregateUDF { .. }
| Expr::InList { .. }
| Expr::Exists { .. }
Expand Down Expand Up @@ -323,6 +325,13 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
| Expr::ScalarUDF { args, .. }
| Expr::AggregateFunction { args, .. }
| Expr::AggregateUDF { args, .. } => Ok(args.clone()),
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Ok(exprs.clone()),
GroupingSet::Cube(exprs) => Ok(exprs.clone()),
GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan(
"GroupingSets are not supported yet".to_string(),
)),
},
Expr::WindowFunction {
args,
partition_by,
Expand Down Expand Up @@ -458,6 +467,17 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
fun: fun.clone(),
args: expressions.to_vec(),
}),
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(_exprs) => {
Ok(Expr::GroupingSet(GroupingSet::Rollup(expressions.to_vec())))
}
GroupingSet::Cube(_exprs) => {
Ok(Expr::GroupingSet(GroupingSet::Rollup(expressions.to_vec())))
}
GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan(
"GroupingSets are not supported yet".to_string(),
)),
},
Expr::Case { .. } => {
let mut base_expr: Option<Box<Expr>> = None;
let mut when_then: Vec<(Box<Expr>, Box<Expr>)> = vec![];
Expand Down
32 changes: 32 additions & 0 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ use arrow::compute::SortOptions;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::{compute::can_cast_types, datatypes::DataType};
use async_trait::async_trait;
use datafusion_expr::expr::GroupingSet;
use datafusion_physical_expr::expressions::DateIntervalExpr;
use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt, TryStreamExt};
Expand Down Expand Up @@ -174,6 +175,37 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
}
Ok(format!("{}({})", fun.name, names.join(",")))
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Ok(format!(
"ROLLUP ({})",
exprs
.iter()
.map(|e| create_physical_name(e, false))
.collect::<Result<Vec<_>>>()?
.join(", ")
)),
GroupingSet::Cube(exprs) => Ok(format!(
"CUBE ({})",
exprs
.iter()
.map(|e| create_physical_name(e, false))
.collect::<Result<Vec<_>>>()?
.join(", ")
)),
GroupingSet::GroupingSets(lists_of_exprs) => {
let mut strings = vec![];
for exprs in lists_of_exprs {
let exprs_str = exprs
.iter()
.map(|e| create_physical_name(e, false))
.collect::<Result<Vec<_>>>()?
.join(", ");
strings.push(format!("({})", exprs_str));
}
Ok(format!("GROUPING SETS ({})", strings.join(", ")))
}
},

Expr::InList {
expr,
list,
Expand Down
Loading

0 comments on commit 1fe038f

Please sign in to comment.