Skip to content

Commit

Permalink
Add SQL planner support for ROLLUP and CUBE grouping sets
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed May 6, 2022
1 parent 22464f0 commit fd08e9a
Show file tree
Hide file tree
Showing 17 changed files with 593 additions and 174 deletions.
1 change: 1 addition & 0 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> {
| Expr::InSubquery { .. }
| Expr::ScalarSubquery(_)
| Expr::GetIndexedField { .. }
| Expr::GroupingSet(_)
| Expr::Case { .. } => Recursion::Continue(self),

Expr::ScalarFunction { fun, .. } => self.visit_volatility(fun.volatility()),
Expand Down
15 changes: 8 additions & 7 deletions datafusion/core/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ use std::{
sync::Arc,
};

use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
use super::{Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
use crate::logical_plan::expr::exprlist_to_fields;
use crate::logical_plan::{
columnize_expr, normalize_col, normalize_cols, provider_as_source,
rewrite_sort_cols_by_aggs, Column, CrossJoin, DFField, DFSchema, DFSchemaRef, Limit,
Expand Down Expand Up @@ -557,7 +558,7 @@ impl LogicalPlanBuilder {
expr.extend(missing_exprs);

let new_schema = DFSchema::new_with_metadata(
exprlist_to_fields(&expr, input_schema)?,
exprlist_to_fields(&expr, &input)?,
input_schema.metadata().clone(),
)?;

Expand Down Expand Up @@ -629,7 +630,7 @@ impl LogicalPlanBuilder {
.map(|f| Expr::Column(f.qualified_column()))
.collect();
let new_schema = DFSchema::new_with_metadata(
exprlist_to_fields(&new_expr, schema)?,
exprlist_to_fields(&new_expr, &self.plan)?,
schema.metadata().clone(),
)?;

Expand Down Expand Up @@ -843,8 +844,7 @@ impl LogicalPlanBuilder {
let window_expr = normalize_cols(window_expr, &self.plan)?;
let all_expr = window_expr.iter();
validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?;
let mut window_fields: Vec<DFField> =
exprlist_to_fields(all_expr, self.plan.schema())?;
let mut window_fields: Vec<DFField> = exprlist_to_fields(all_expr, &self.plan)?;
window_fields.extend_from_slice(self.plan.schema().fields());
Ok(Self::from(LogicalPlan::Window(Window {
input: Arc::new(self.plan.clone()),
Expand All @@ -869,7 +869,7 @@ impl LogicalPlanBuilder {
let all_expr = group_expr.iter().chain(aggr_expr.iter());
validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?;
let aggr_schema = DFSchema::new_with_metadata(
exprlist_to_fields(all_expr, self.plan.schema())?,
exprlist_to_fields(all_expr, &self.plan)?,
self.plan.schema().metadata().clone(),
)?;
Ok(Self::from(LogicalPlan::Aggregate(Aggregate {
Expand Down Expand Up @@ -1126,13 +1126,14 @@ pub fn project_with_alias(
}
validate_unique_names("Projections", projected_expr.iter(), input_schema)?;
let input_schema = DFSchema::new_with_metadata(
exprlist_to_fields(&projected_expr, input_schema)?,
exprlist_to_fields(&projected_expr, &plan)?,
plan.schema().metadata().clone(),
)?;
let schema = match alias {
Some(ref alias) => input_schema.replace_qualifier(alias.as_str()),
None => input_schema,
};

Ok(LogicalPlan::Projection(Projection {
expr: projected_expr,
input: Arc::new(plan.clone()),
Expand Down
38 changes: 37 additions & 1 deletion datafusion/core/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ use crate::logical_plan::{DFField, DFSchema};
use arrow::datatypes::DataType;
pub use datafusion_common::{Column, ExprSchema};
pub use datafusion_expr::expr_fn::*;
use datafusion_expr::AccumulatorFunctionImplementation;
use datafusion_expr::BuiltinScalarFunction;
pub use datafusion_expr::Expr;
use datafusion_expr::StateTypeFunction;
pub use datafusion_expr::{lit, lit_timestamp_nano, Literal};
use datafusion_expr::{AccumulatorFunctionImplementation, LogicalPlan};
use datafusion_expr::{AggregateUDF, ScalarUDF};
use datafusion_expr::{
ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility,
Expand Down Expand Up @@ -137,6 +137,42 @@ pub fn create_udaf(

/// Create field meta-data from a list of expressions, for use in a result set schema.
///
/// Each expression is converted to a [`DFField`] by resolving it against a schema
/// chosen from `plan`:
///
/// * If `plan` is a `LogicalPlan::Aggregate` and the expression is a plain
///   `Expr::Column` that is one of the columns referenced by the aggregate's
///   group expressions, the column is resolved against the schema of the
///   aggregate's *input*.
/// * In every other case the expression is resolved against `plan.schema()`.
///   The original implementation marks this path as legacy behavior with known issues.
///
/// # Errors
///
/// Returns the first error produced while collecting the aggregate's group
/// columns or while converting an expression to a field; later expressions are
/// not evaluated once an error occurs.
pub fn exprlist_to_fields<'a>(
    expr: impl IntoIterator<Item = &'a Expr>,
    plan: &LogicalPlan,
) -> Result<Vec<DFField>> {
    expr.into_iter()
        .map(|e| {
            // Only bare column references under an Aggregate get special treatment.
            if let (Expr::Column(c), LogicalPlan::Aggregate(agg)) = (e, plan) {
                // Evaluated lazily (only when a column is actually encountered) so
                // that expression lists without columns never trigger this lookup.
                let group_columns = agg.columns_in_group_expr()?;
                if group_columns.into_iter().any(|g| &g == c) {
                    // Grouping columns are resolved against the aggregate's input schema.
                    return e.to_field(agg.input.schema());
                }
            }
            // Legacy behavior, which has known issues: resolve against the plan's own schema.
            e.to_field(plan.schema())
        })
        .collect()
}

/// Create field meta-data from an expression, for use in a result set schema
pub fn exprlist_to_fields_from_schema<'a>(
expr: impl IntoIterator<Item = &'a Expr>,
input_schema: &DFSchema,
) -> Result<Vec<DFField>> {
Expand Down
17 changes: 17 additions & 0 deletions datafusion/core/src/logical_plan/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::logical_plan::ExprSchemable;
use crate::logical_plan::LogicalPlan;
use datafusion_common::Column;
use datafusion_common::Result;
use datafusion_expr::expr::GroupingSet;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
Expand Down Expand Up @@ -215,6 +216,22 @@ impl ExprRewritable for Expr {
fun,
distinct,
},
// Rewrite every expression nested inside a grouping set, rebuilding the same
// grouping-set variant around the rewritten children.
Expr::GroupingSet(grouping_set) => match grouping_set {
    GroupingSet::Rollup(exprs) => {
        Expr::GroupingSet(GroupingSet::Rollup(rewrite_vec(exprs, rewriter)?))
    }
    GroupingSet::Cube(exprs) => {
        Expr::GroupingSet(GroupingSet::Cube(rewrite_vec(exprs, rewriter)?))
    }
    GroupingSet::GroupingSets(lists_of_exprs) => {
        // `lists_of_exprs` is owned here, so consume it instead of cloning
        // each inner list before handing it to `rewrite_vec`.
        Expr::GroupingSet(GroupingSet::GroupingSets(
            lists_of_exprs
                .into_iter()
                .map(|exprs| rewrite_vec(exprs, rewriter))
                .collect::<Result<Vec<_>>>()?,
        ))
    }
},
Expr::AggregateUDF { args, fun } => Expr::AggregateUDF {
args: rewrite_vec(args, rewriter)?,
fun,
Expand Down
14 changes: 14 additions & 0 deletions datafusion/core/src/logical_plan/expr_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
use super::Expr;
use datafusion_common::Result;
use datafusion_expr::expr::GroupingSet;

/// Controls how the visitor recursion should proceed.
pub enum Recursion<V: ExpressionVisitor> {
Expand Down Expand Up @@ -103,6 +104,19 @@ impl ExprVisitable for Expr {
| Expr::TryCast { expr, .. }
| Expr::Sort { expr, .. }
| Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
Expr::GroupingSet(GroupingSet::Rollup(exprs)) => exprs
.iter()
.fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))),
Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs
.iter()
.fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))),
Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
lists_of_exprs.iter().fold(Ok(visitor), |v, exprs| {
v.and_then(|v| {
exprs.iter().fold(Ok(v), |v, e| v.and_then(|v| e.accept(v)))
})
})
}
Expr::Column(_)
| Expr::ScalarVariable(_, _)
| Expr::Literal(_)
Expand Down
12 changes: 6 additions & 6 deletions datafusion/core/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ pub use expr::{
avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, coalesce, col,
columnize_expr, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos,
count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest,
exists, exp, exprlist_to_fields, floor, in_list, in_subquery, initcap, left, length,
lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, max, md5, min,
not_exists, not_in_subquery, now, now_expr, nullif, octet_length, or, power, random,
regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim,
scalar_subquery, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
starts_with, strpos, substr, sum, tan, to_hex, to_timestamp_micros,
exists, exp, exprlist_to_fields_from_schema, floor, in_list, in_subquery, initcap,
left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, max, md5,
min, not_exists, not_in_subquery, now, now_expr, nullif, octet_length, or, power,
random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad,
rtrim, scalar_subquery, sha224, sha256, sha384, sha512, signum, sin, split_part,
sqrt, starts_with, strpos, substr, sum, tan, to_hex, to_timestamp_micros,
to_timestamp_millis, to_timestamp_seconds, translate, trim, trunc, unalias, upper,
when, Column, Expr, ExprSchema, Literal,
};
Expand Down
28 changes: 28 additions & 0 deletions datafusion/core/src/optimizer/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use crate::logical_plan::{
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use arrow::datatypes::DataType;
use datafusion_expr::expr::GroupingSet;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

Expand Down Expand Up @@ -482,6 +483,33 @@ impl ExprIdentifierVisitor<'_> {
desc.push_str("GetIndexedField-");
desc.push_str(&key.to_string());
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
desc.push_str("Rollup");
for expr in exprs {
desc.push_str("-");
desc.push_str(&Self::desc_expr(expr));
}
}
GroupingSet::Cube(exprs) => {
desc.push_str("Cube");
for expr in exprs {
desc.push_str("-");
desc.push_str(&Self::desc_expr(expr));
}
}
GroupingSet::GroupingSets(lists_of_exprs) => {
desc.push_str("GroupingSets");
for exprs in lists_of_exprs {
desc.push_str("(");
for expr in exprs {
desc.push_str("-");
desc.push_str(&Self::desc_expr(expr));
}
desc.push_str(")");
}
}
},
}

desc
Expand Down
6 changes: 4 additions & 2 deletions datafusion/core/src/optimizer/projection_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,8 @@ mod tests {

use super::*;
use crate::logical_plan::{
col, exprlist_to_fields, lit, max, min, Expr, JoinType, LogicalPlanBuilder,
col, exprlist_to_fields_from_schema, lit, max, min, Expr, JoinType,
LogicalPlanBuilder,
};
use crate::test::*;
use arrow::datatypes::DataType;
Expand Down Expand Up @@ -810,7 +811,8 @@ mod tests {
// that the Column references are unqualified (e.g. their
// relation is `None`). PlanBuilder resolves the expressions
let expr = vec![col("a"), col("b")];
let projected_fields = exprlist_to_fields(&expr, input_schema).unwrap();
let projected_fields =
exprlist_to_fields_from_schema(&expr, input_schema).unwrap();
let projected_schema = DFSchema::new_with_metadata(
projected_fields,
input_schema.metadata().clone(),
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/src/optimizer/simplify_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ impl<'a> ConstEvaluator<'a> {
| Expr::ScalarSubquery(_)
| Expr::WindowFunction { .. }
| Expr::Sort { .. }
| Expr::GroupingSet(_)
| Expr::Wildcard
| Expr::QualifiedWildcard { .. } => false,
Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()),
Expand Down
20 changes: 20 additions & 0 deletions datafusion/core/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use crate::{
logical_plan::ExpressionVisitor,
};
use datafusion_common::DFSchema;
use datafusion_expr::expr::GroupingSet;
use std::{collections::HashSet, sync::Arc};

const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__";
Expand Down Expand Up @@ -83,6 +84,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> {
| Expr::ScalarUDF { .. }
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
| Expr::GroupingSet(_)
| Expr::AggregateUDF { .. }
| Expr::InList { .. }
| Expr::Exists { .. }
Expand Down Expand Up @@ -323,6 +325,13 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
| Expr::ScalarUDF { args, .. }
| Expr::AggregateFunction { args, .. }
| Expr::AggregateUDF { args, .. } => Ok(args.clone()),
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Ok(exprs.clone()),
GroupingSet::Cube(exprs) => Ok(exprs.clone()),
GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan(
"GroupingSets are not supported yet".to_string(),
)),
},
Expr::WindowFunction {
args,
partition_by,
Expand Down Expand Up @@ -458,6 +467,17 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
fun: fun.clone(),
args: expressions.to_vec(),
}),
// Rebuild a grouping set from the replacement `expressions`, keeping the
// original variant. GROUPING SETS cannot be flattened into a single list
// of sub-expressions, so it is rejected with a plan error.
Expr::GroupingSet(grouping_set) => match grouping_set {
    GroupingSet::Rollup(_exprs) => {
        Ok(Expr::GroupingSet(GroupingSet::Rollup(expressions.to_vec())))
    }
    GroupingSet::Cube(_exprs) => {
        // Must stay a CUBE: rebuilding it as ROLLUP would silently change
        // the set of groupings the query produces.
        Ok(Expr::GroupingSet(GroupingSet::Cube(expressions.to_vec())))
    }
    GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan(
        "GroupingSets are not supported yet".to_string(),
    )),
},
Expr::Case { .. } => {
let mut base_expr: Option<Box<Expr>> = None;
let mut when_then: Vec<(Box<Expr>, Box<Expr>)> = vec![];
Expand Down
32 changes: 32 additions & 0 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ use arrow::compute::SortOptions;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::{compute::can_cast_types, datatypes::DataType};
use async_trait::async_trait;
use datafusion_expr::expr::GroupingSet;
use datafusion_physical_expr::expressions::DateIntervalExpr;
use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt, TryStreamExt};
Expand Down Expand Up @@ -174,6 +175,37 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
}
Ok(format!("{}({})", fun.name, names.join(",")))
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Ok(format!(
"ROLLUP ({})",
exprs
.iter()
.map(|e| create_physical_name(e, false))
.collect::<Result<Vec<_>>>()?
.join(", ")
)),
GroupingSet::Cube(exprs) => Ok(format!(
"CUBE ({})",
exprs
.iter()
.map(|e| create_physical_name(e, false))
.collect::<Result<Vec<_>>>()?
.join(", ")
)),
GroupingSet::GroupingSets(lists_of_exprs) => {
let mut strings = vec![];
for exprs in lists_of_exprs {
let exprs_str = exprs
.iter()
.map(|e| create_physical_name(e, false))
.collect::<Result<Vec<_>>>()?
.join(", ");
strings.push(format!("({})", exprs_str));
}
Ok(format!("GROUPING SETS ({})", strings.join(", ")))
}
},

Expr::InList {
expr,
list,
Expand Down
Loading

0 comments on commit fd08e9a

Please sign in to comment.