diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs
index 9518986a14da..11a91f2eeaa8 100644
--- a/datafusion/core/src/datasource/listing/helpers.rs
+++ b/datafusion/core/src/datasource/listing/helpers.rs
@@ -96,6 +96,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> {
             | Expr::InSubquery { .. }
             | Expr::ScalarSubquery(_)
             | Expr::GetIndexedField { .. }
+            | Expr::GroupingSet(_)
             | Expr::Case { .. } => Recursion::Continue(self),
 
             Expr::ScalarFunction { fun, .. } => self.visit_volatility(fun.volatility()),
diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs
index 8a0ea6d6667f..80ebfadd1628 100644
--- a/datafusion/core/src/logical_plan/builder.rs
+++ b/datafusion/core/src/logical_plan/builder.rs
@@ -43,7 +43,8 @@ use std::{
     sync::Arc,
 };
 
-use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
+use super::{Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
+use crate::logical_plan::expr::exprlist_to_fields;
 use crate::logical_plan::{
     columnize_expr, normalize_col, normalize_cols, provider_as_source,
     rewrite_sort_cols_by_aggs, Column, CrossJoin, DFField, DFSchema, DFSchemaRef, Limit,
@@ -557,7 +558,7 @@ impl LogicalPlanBuilder {
         expr.extend(missing_exprs);
 
         let new_schema = DFSchema::new_with_metadata(
-            exprlist_to_fields(&expr, input_schema)?,
+            exprlist_to_fields(&expr, &input)?,
             input_schema.metadata().clone(),
         )?;
 
@@ -629,7 +630,7 @@ impl LogicalPlanBuilder {
             .map(|f| Expr::Column(f.qualified_column()))
             .collect();
         let new_schema = DFSchema::new_with_metadata(
-            exprlist_to_fields(&new_expr, schema)?,
+            exprlist_to_fields(&new_expr, &self.plan)?,
             schema.metadata().clone(),
         )?;
 
@@ -843,8 +844,7 @@ impl LogicalPlanBuilder {
         let window_expr = normalize_cols(window_expr, &self.plan)?;
         let all_expr = window_expr.iter();
         validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?;
-        let mut window_fields: Vec<DFField> =
-            exprlist_to_fields(all_expr, self.plan.schema())?;
+        let mut window_fields: Vec<DFField> = exprlist_to_fields(all_expr, &self.plan)?;
         window_fields.extend_from_slice(self.plan.schema().fields());
         Ok(Self::from(LogicalPlan::Window(Window {
             input: Arc::new(self.plan.clone()),
@@ -869,7 +869,7 @@ impl LogicalPlanBuilder {
         let all_expr = group_expr.iter().chain(aggr_expr.iter());
         validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?;
         let aggr_schema = DFSchema::new_with_metadata(
-            exprlist_to_fields(all_expr, self.plan.schema())?,
+            exprlist_to_fields(all_expr, &self.plan)?,
             self.plan.schema().metadata().clone(),
         )?;
         Ok(Self::from(LogicalPlan::Aggregate(Aggregate {
@@ -1126,13 +1126,14 @@ pub fn project_with_alias(
     }
     validate_unique_names("Projections", projected_expr.iter(), input_schema)?;
     let input_schema = DFSchema::new_with_metadata(
-        exprlist_to_fields(&projected_expr, input_schema)?,
+        exprlist_to_fields(&projected_expr, &plan)?,
         plan.schema().metadata().clone(),
     )?;
     let schema = match alias {
         Some(ref alias) => input_schema.replace_qualifier(alias.as_str()),
         None => input_schema,
     };
+
     Ok(LogicalPlan::Projection(Projection {
         expr: projected_expr,
         input: Arc::new(plan.clone()),
diff --git a/datafusion/core/src/logical_plan/expr.rs b/datafusion/core/src/logical_plan/expr.rs
index 673345c69b61..3ffc1894e554 100644
--- a/datafusion/core/src/logical_plan/expr.rs
+++ b/datafusion/core/src/logical_plan/expr.rs
@@ -22,14 +22,15 @@ pub use super::Operator;
 use crate::error::Result;
 use crate::logical_plan::ExprSchemable;
 use crate::logical_plan::{DFField, DFSchema};
+use crate::sql::utils::find_columns_referenced_by_expr;
 use arrow::datatypes::DataType;
 pub use datafusion_common::{Column, ExprSchema};
 pub use datafusion_expr::expr_fn::*;
-use datafusion_expr::AccumulatorFunctionImplementation;
 use datafusion_expr::BuiltinScalarFunction;
 pub use datafusion_expr::Expr;
 use datafusion_expr::StateTypeFunction;
 pub use datafusion_expr::{lit, lit_timestamp_nano, Literal};
+use datafusion_expr::{AccumulatorFunctionImplementation, LogicalPlan};
 use datafusion_expr::{AggregateUDF, ScalarUDF};
 use datafusion_expr::{
     ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility,
@@ -138,9 +139,33 @@ pub fn create_udaf(
 /// Create field meta-data from an expression, for use in a result set schema
 pub fn exprlist_to_fields<'a>(
     expr: impl IntoIterator<Item = &'a Expr>,
-    input_schema: &DFSchema,
+    plan: &LogicalPlan,
 ) -> Result<Vec<DFField>> {
-    expr.into_iter().map(|e| e.to_field(input_schema)).collect()
+    match plan {
+        LogicalPlan::Aggregate(agg) => {
+            let group_expr: Vec<Column> = agg
+                .group_expr
+                .iter()
+                .flat_map(find_columns_referenced_by_expr)
+                .collect();
+            let exprs: Vec<Expr> = expr.into_iter().cloned().collect();
+            let mut fields = vec![];
+            for expr in &exprs {
+                match expr {
+                    Expr::Column(c) if group_expr.iter().any(|x| x == c) => {
+                        // resolve against schema of input to aggregate
+                        fields.push(expr.to_field(agg.input.schema())?);
+                    }
+                    _ => fields.push(expr.to_field(plan.schema())?),
+                }
+            }
+            Ok(fields)
+        }
+        _ => {
+            let input_schema = &plan.schema();
+            expr.into_iter().map(|e| e.to_field(input_schema)).collect()
+        }
+    }
 }
 
 /// Calls a named built in function
diff --git a/datafusion/core/src/logical_plan/expr_rewriter.rs b/datafusion/core/src/logical_plan/expr_rewriter.rs
index 4e94768993d5..1f24556eaa80 100644
--- a/datafusion/core/src/logical_plan/expr_rewriter.rs
+++ b/datafusion/core/src/logical_plan/expr_rewriter.rs
@@ -24,6 +24,7 @@ use crate::logical_plan::ExprSchemable;
 use crate::logical_plan::LogicalPlan;
 use datafusion_common::Column;
 use datafusion_common::Result;
+use datafusion_expr::expr::GroupingSet;
 use std::collections::HashMap;
 use std::collections::HashSet;
 use std::sync::Arc;
@@ -215,6 +216,22 @@ impl ExprRewritable for Expr {
                 fun,
                 distinct,
             },
+            Expr::GroupingSet(grouping_set) => match grouping_set {
+                GroupingSet::Rollup(exprs) => {
+                    Expr::GroupingSet(GroupingSet::Rollup(rewrite_vec(exprs, rewriter)?))
+                }
+                GroupingSet::Cube(exprs) => {
+                    Expr::GroupingSet(GroupingSet::Cube(rewrite_vec(exprs, rewriter)?))
+                }
+                GroupingSet::GroupingSets(lists_of_exprs) => {
+                    Expr::GroupingSet(GroupingSet::GroupingSets(
+                        lists_of_exprs
+                            .iter()
+                            .map(|exprs| rewrite_vec(exprs.clone(), rewriter))
+                            .collect::<Result<Vec<_>>>()?,
+                    ))
+                }
+            },
             Expr::AggregateUDF { args, fun } => Expr::AggregateUDF {
                 args: rewrite_vec(args, rewriter)?,
                 fun,
diff --git a/datafusion/core/src/logical_plan/expr_visitor.rs b/datafusion/core/src/logical_plan/expr_visitor.rs
index 7c578da19b75..24acb65bcbab 100644
--- a/datafusion/core/src/logical_plan/expr_visitor.rs
+++ b/datafusion/core/src/logical_plan/expr_visitor.rs
@@ -19,6 +19,7 @@
 
 use super::Expr;
 use datafusion_common::Result;
+use datafusion_expr::expr::GroupingSet;
 
 /// Controls how the visitor recursion should proceed.
 pub enum Recursion {
@@ -103,6 +104,19 @@ impl ExprVisitable for Expr {
             | Expr::TryCast { expr, .. }
             | Expr::Sort { expr, .. }
             | Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
+            Expr::GroupingSet(GroupingSet::Rollup(exprs)) => exprs
+                .iter()
+                .fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))),
+            Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs
+                .iter()
+                .fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))),
+            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
+                lists_of_exprs.iter().fold(Ok(visitor), |v, exprs| {
+                    v.and_then(|v| {
+                        exprs.iter().fold(Ok(v), |v, e| v.and_then(|v| e.accept(v)))
+                    })
+                })
+            }
             Expr::Column(_)
             | Expr::ScalarVariable(_, _)
             | Expr::Literal(_)
diff --git a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs
index a9983cdf1e08..967ef58b39c4 100644
--- a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs
+++ b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs
@@ -29,6 +29,7 @@ use crate::logical_plan::{
 use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::utils;
 use arrow::datatypes::DataType;
+use datafusion_expr::expr::GroupingSet;
 use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 
@@ -482,6 +483,33 @@ impl ExprIdentifierVisitor<'_> {
                 desc.push_str("GetIndexedField-");
                 desc.push_str(&key.to_string());
             }
+            Expr::GroupingSet(grouping_set) => match grouping_set {
+                GroupingSet::Rollup(exprs) => {
+                    desc.push_str("Rollup");
+                    for expr in exprs {
+                        desc.push('-');
+                        desc.push_str(&Self::desc_expr(expr));
+                    }
+                }
+                GroupingSet::Cube(exprs) => {
+                    desc.push_str("Cube");
+                    for expr in exprs {
+                        desc.push('-');
+                        desc.push_str(&Self::desc_expr(expr));
+                    }
+                }
+                GroupingSet::GroupingSets(lists_of_exprs) => {
+                    desc.push_str("GroupingSets");
+                    for exprs in lists_of_exprs {
+                        desc.push('(');
+                        for expr in exprs {
+                            desc.push('-');
+                            desc.push_str(&Self::desc_expr(expr));
+                        }
+                        desc.push(')');
+                    }
+                }
+            },
         }
 
         desc
diff --git a/datafusion/core/src/optimizer/projection_push_down.rs b/datafusion/core/src/optimizer/projection_push_down.rs
index 5062082e8643..0979d8f5b218 100644
--- a/datafusion/core/src/optimizer/projection_push_down.rs
+++ b/datafusion/core/src/optimizer/projection_push_down.rs
@@ -810,7 +810,7 @@ mod tests {
         // that the Column references are unqualified (e.g. their
        // relation is `None`). PlanBuilder resolves the expressions
         let expr = vec![col("a"), col("b")];
-        let projected_fields = exprlist_to_fields(&expr, input_schema).unwrap();
+        let projected_fields = exprlist_to_fields(&expr, &table_scan).unwrap();
         let projected_schema = DFSchema::new_with_metadata(
             projected_fields,
             input_schema.metadata().clone(),
diff --git a/datafusion/core/src/optimizer/simplify_expressions.rs b/datafusion/core/src/optimizer/simplify_expressions.rs
index 4dfbb6eb6543..e9694ebc528c 100644
--- a/datafusion/core/src/optimizer/simplify_expressions.rs
+++ b/datafusion/core/src/optimizer/simplify_expressions.rs
@@ -380,6 +380,7 @@ impl<'a> ConstEvaluator<'a> {
             | Expr::ScalarSubquery(_)
             | Expr::WindowFunction { .. }
             | Expr::Sort { .. }
+            | Expr::GroupingSet(_)
             | Expr::Wildcard
             | Expr::QualifiedWildcard { .. } => false,
             Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()),
diff --git a/datafusion/core/src/optimizer/utils.rs b/datafusion/core/src/optimizer/utils.rs
index 48855df9f8e8..2c56b5f893c3 100644
--- a/datafusion/core/src/optimizer/utils.rs
+++ b/datafusion/core/src/optimizer/utils.rs
@@ -36,6 +36,7 @@ use crate::{
     logical_plan::ExpressionVisitor,
 };
 use datafusion_common::DFSchema;
+use datafusion_expr::expr::GroupingSet;
 use std::{collections::HashSet, sync::Arc};
 
 const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__";
@@ -83,6 +84,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> {
             | Expr::ScalarUDF { .. }
             | Expr::WindowFunction { .. }
             | Expr::AggregateFunction { .. }
+            | Expr::GroupingSet(_)
             | Expr::AggregateUDF { .. }
             | Expr::InList { .. }
             | Expr::Exists { .. }
@@ -323,6 +325,13 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
         | Expr::ScalarUDF { args, .. }
         | Expr::AggregateFunction { args, .. }
        | Expr::AggregateUDF { args, .. } => Ok(args.clone()),
+        Expr::GroupingSet(grouping_set) => match grouping_set {
+            GroupingSet::Rollup(exprs) => Ok(exprs.clone()),
+            GroupingSet::Cube(exprs) => Ok(exprs.clone()),
+            GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan(
+                "GroupingSets are not supported yet".to_string(),
+            )),
+        },
         Expr::WindowFunction {
             args,
             partition_by,
@@ -458,6 +467,17 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
             fun: fun.clone(),
             args: expressions.to_vec(),
         }),
+        Expr::GroupingSet(grouping_set) => match grouping_set {
+            GroupingSet::Rollup(_exprs) => {
+                Ok(Expr::GroupingSet(GroupingSet::Rollup(expressions.to_vec())))
+            }
+            GroupingSet::Cube(_exprs) => {
+                Ok(Expr::GroupingSet(GroupingSet::Cube(expressions.to_vec())))
+            }
+            GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan(
+                "GroupingSets are not supported yet".to_string(),
+            )),
+        },
         Expr::Case { .. } => {
             let mut base_expr: Option<Box<Expr>> = None;
             let mut when_then: Vec<(Box<Expr>, Box<Expr>)> = vec![];
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 85fb7d424fac..f6b3842f243e 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -62,6 +62,7 @@ use arrow::compute::SortOptions;
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::{compute::can_cast_types, datatypes::DataType};
 use async_trait::async_trait;
+use datafusion_expr::expr::GroupingSet;
 use datafusion_physical_expr::expressions::DateIntervalExpr;
 use futures::future::BoxFuture;
 use futures::{FutureExt, StreamExt, TryStreamExt};
@@ -174,6 +175,37 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
             }
             Ok(format!("{}({})", fun.name, names.join(",")))
         }
+        Expr::GroupingSet(grouping_set) => match grouping_set {
+            GroupingSet::Rollup(exprs) => Ok(format!(
+                "ROLLUP ({})",
+                exprs
+                    .iter()
+                    .map(|e| create_physical_name(e, false))
+                    .collect::<Result<Vec<_>>>()?
+                    .join(", ")
+            )),
+            GroupingSet::Cube(exprs) => Ok(format!(
+                "CUBE ({})",
+                exprs
+                    .iter()
+                    .map(|e| create_physical_name(e, false))
+                    .collect::<Result<Vec<_>>>()?
+                    .join(", ")
+            )),
+            GroupingSet::GroupingSets(lists_of_exprs) => {
+                let mut strings = vec![];
+                for exprs in lists_of_exprs {
+                    let exprs_str = exprs
+                        .iter()
+                        .map(|e| create_physical_name(e, false))
+                        .collect::<Result<Vec<_>>>()?
+                        .join(", ");
+                    strings.push(format!("({})", exprs_str));
+                }
+                Ok(format!("GROUPING SETS ({})", strings.join(", ")))
+            }
+        },
+
         Expr::InList {
             expr,
             list,
diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs
index 33391d91e86c..af8329018f67 100644
--- a/datafusion/core/src/sql/planner.rs
+++ b/datafusion/core/src/sql/planner.rs
@@ -50,6 +50,7 @@ use datafusion_expr::{window_function::WindowFunction, BuiltinScalarFunction};
 use hashbrown::HashMap;
 
 use datafusion_common::field_not_found;
+use datafusion_expr::expr::GroupingSet;
 use datafusion_expr::logical_plan::{Filter, Subquery};
 use sqlparser::ast::{
     BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, FunctionArg,
@@ -1156,11 +1157,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
 
         // combine the original grouping and aggregate expressions into one list (note that
         // we do not add the "having" expression since that is not part of the projection)
-        let aggr_projection_exprs = group_by_exprs
-            .iter()
-            .chain(aggr_exprs.iter())
-            .cloned()
-            .collect::<Vec<Expr>>();
+        let mut aggr_projection_exprs = vec![];
+        for expr in &group_by_exprs {
+            match expr {
+                Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
+                    aggr_projection_exprs.extend_from_slice(exprs)
+                }
+                Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
+                    aggr_projection_exprs.extend_from_slice(exprs)
+                }
+                Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
+                    for exprs in lists_of_exprs {
+                        aggr_projection_exprs.extend_from_slice(exprs)
+                    }
+                }
+                _ => aggr_projection_exprs.push(expr.clone()),
+            }
+        }
+        aggr_projection_exprs.extend_from_slice(&aggr_exprs);
 
         // now attempt to resolve columns and replace with fully-qualified columns
         let aggr_projection_exprs = aggr_projection_exprs
@@ -1885,10 +1899,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             normalize_ident(&function.name.0[0])
         };
 
-        // first, scalar built-in
-        if let Ok(fun) = BuiltinScalarFunction::from_str(&name) {
+        // first, check SQL reserved words
+        if name == "rollup" {
+            let args = self.function_args_to_expr(function.args, schema)?;
+            return Ok(Expr::GroupingSet(GroupingSet::Rollup(args)));
+        } else if name == "cube" {
             let args = self.function_args_to_expr(function.args, schema)?;
+            return Ok(Expr::GroupingSet(GroupingSet::Cube(args)));
+        }
 
+        // next, scalar built-in
+        if let Ok(fun) = BuiltinScalarFunction::from_str(&name) {
+            let args = self.function_args_to_expr(function.args, schema)?;
             return Ok(Expr::ScalarFunction { fun, args });
         };
 
@@ -4654,6 +4676,33 @@ mod tests {
         quick_test(sql, &expected)
     }
 
+    #[tokio::test]
+    async fn aggregate_with_rollup() {
+        let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, ROLLUP (state, age)";
+        let expected = "Projection: #person.id, #person.state, #person.age, #COUNT(UInt8(1))\
+        \n  Aggregate: groupBy=[[#person.id, ROLLUP (#person.state, #person.age)]], aggr=[[COUNT(UInt8(1))]]\
+        \n    TableScan: person projection=None";
+        quick_test(sql, expected);
+    }
+
+    #[tokio::test]
+    async fn aggregate_with_cube() {
+        let sql =
+            "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, CUBE (state, age)";
+        let expected = "Projection: #person.id, #person.state, #person.age, #COUNT(UInt8(1))\
+        \n  Aggregate: groupBy=[[#person.id, CUBE (#person.state, #person.age)]], aggr=[[COUNT(UInt8(1))]]\
+        \n    TableScan: person projection=None";
+        quick_test(sql, expected);
+    }
+
+    #[ignore] // see https://github.com/apache/arrow-datafusion/issues/2469
+    #[tokio::test]
+    async fn aggregate_with_grouping_sets() {
+        let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))";
+        let expected = "TBD";
+        quick_test(sql, expected);
+    }
+
     fn assert_field_not_found(err: DataFusionError, name: &str) {
         match err {
             DataFusionError::SchemaError { .. } => {
diff --git a/datafusion/core/src/sql/utils.rs b/datafusion/core/src/sql/utils.rs
index 0293e241023d..b2cf1f6987e1 100644
--- a/datafusion/core/src/sql/utils.rs
+++ b/datafusion/core/src/sql/utils.rs
@@ -27,6 +27,7 @@ use crate::{
     error::{DataFusionError, Result},
     logical_plan::{Column, ExpressionVisitor, Recursion},
 };
+use datafusion_expr::expr::GroupingSet;
 use std::collections::HashMap;
 
 /// Collect all deeply nested `Expr::AggregateFunction` and
@@ -100,7 +101,7 @@ impl ExpressionVisitor for ColumnCollector {
     }
 }
 
-fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
+pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
     // As the `ExpressionVisitor` impl above always returns Ok, this
     // "can't" error
     let ColumnCollector { exprs } = e
@@ -235,22 +236,49 @@ pub(crate) fn check_columns_satisfy_exprs(
             "Expr::Column are required".to_string(),
         )),
     })?;
-
-    for e in &find_column_exprs(exprs) {
-        if !columns.contains(e) {
-            return Err(DataFusionError::Plan(format!(
-                "{}: Expression {:?} could not be resolved from available columns: {}",
-                message_prefix,
-                e,
-                columns
-                    .iter()
-                    .map(|e| format!("{}", e))
-                    .collect::<Vec<String>>()
-                    .join(", ")
-            )));
+    let column_exprs = find_column_exprs(exprs);
+    for e in &column_exprs {
+        match e {
+            Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
+                for e in exprs {
+                    check_column_satisfies_expr(columns, e, message_prefix)?;
+                }
+            }
+            Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
+                for e in exprs {
+                    check_column_satisfies_expr(columns, e, message_prefix)?;
+                }
+            }
+            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
+                for exprs in lists_of_exprs {
+                    for e in exprs {
+                        check_column_satisfies_expr(columns, e, message_prefix)?;
+                    }
+                }
+            }
+            _ => check_column_satisfies_expr(columns, e, message_prefix)?,
         }
     }
+    Ok(())
+}
+
+fn check_column_satisfies_expr(
+    columns: &[Expr],
+    expr: &Expr,
+    message_prefix: &str,
+) -> Result<()> {
+    if !columns.contains(expr) {
+        return Err(DataFusionError::Plan(format!(
+            "{}: Expression {:?} could not be resolved from available columns: {}",
+            message_prefix,
+            expr,
+            columns
+                .iter()
+                .map(|e| format!("{}", e))
+                .collect::<Vec<String>>()
+                .join(", ")
+        )));
+    }
     Ok(())
 }
 
@@ -456,6 +484,34 @@ where
                 expr: Box::new(clone_with_replacement(expr.as_ref(), replacement_fn)?),
                 key: key.clone(),
             }),
+            Expr::GroupingSet(set) => match set {
+                GroupingSet::Rollup(exprs) => Ok(Expr::GroupingSet(GroupingSet::Rollup(
+                    exprs
+                        .iter()
+                        .map(|e| clone_with_replacement(e, replacement_fn))
+                        .collect::<Result<Vec<Expr>>>()?,
+                ))),
+                GroupingSet::Cube(exprs) => Ok(Expr::GroupingSet(GroupingSet::Cube(
+                    exprs
+                        .iter()
+                        .map(|e| clone_with_replacement(e, replacement_fn))
+                        .collect::<Result<Vec<Expr>>>()?,
+                ))),
+                GroupingSet::GroupingSets(lists_of_exprs) => {
+                    let mut new_lists_of_exprs = vec![];
+                    for exprs in lists_of_exprs {
+                        new_lists_of_exprs.push(
+                            exprs
+                                .iter()
+                                .map(|e| clone_with_replacement(e, replacement_fn))
+                                .collect::<Result<Vec<Expr>>>()?,
+                        );
+                    }
+                    Ok(Expr::GroupingSet(GroupingSet::GroupingSets(
+                        new_lists_of_exprs,
+                    )))
+                }
+            },
         },
     }
 }
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 4d88ed815b14..c1c61d1ff049 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -249,6 +249,24 @@ pub enum Expr {
     Wildcard,
     /// Represents a reference to all fields in a specific schema.
     QualifiedWildcard { qualifier: String },
+    /// List of grouping set expressions. Only valid in the context of an aggregate
+    /// GROUP BY expression list
+    GroupingSet(GroupingSet),
+}
+
+/// Grouping sets
+/// See https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS
+/// for Postgres definition.
+/// See https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-groupby.html
+/// for Apache Spark definition.
+#[derive(Clone, PartialEq, Hash)]
+pub enum GroupingSet {
+    /// Rollup grouping sets
+    Rollup(Vec<Expr>),
+    /// Cube grouping sets
+    Cube(Vec<Expr>),
+    /// User-defined grouping sets
+    GroupingSets(Vec<Vec<Expr>>),
 }
 
 /// Fixed seed for the hashing so that Ords are consistent across runs
@@ -556,6 +574,51 @@ impl fmt::Debug for Expr {
             Expr::GetIndexedField { ref expr, key } => {
                 write!(f, "({:?})[{}]", expr, key)
             }
+            Expr::GroupingSet(grouping_sets) => match grouping_sets {
+                GroupingSet::Rollup(exprs) => {
+                    // ROLLUP (c0, c1, c2)
+                    write!(
+                        f,
+                        "ROLLUP ({})",
+                        exprs
+                            .iter()
+                            .map(|e| format!("{}", e))
+                            .collect::<Vec<String>>()
+                            .join(", ")
+                    )
+                }
+                GroupingSet::Cube(exprs) => {
+                    // CUBE (c0, c1, c2)
+                    write!(
+                        f,
+                        "CUBE ({})",
+                        exprs
+                            .iter()
+                            .map(|e| format!("{}", e))
+                            .collect::<Vec<String>>()
+                            .join(", ")
+                    )
+                }
+                GroupingSet::GroupingSets(lists_of_exprs) => {
+                    // GROUPING SETS ((c0), (c1, c2), (c3, c4))
+                    write!(
+                        f,
+                        "GROUPING SETS ({})",
+                        lists_of_exprs
+                            .iter()
+                            .map(|exprs| format!(
+                                "({})",
+                                exprs
+                                    .iter()
+                                    .map(|e| format!("{}", e))
+                                    .collect::<Vec<String>>()
+                                    .join(", ")
+                            ))
+                            .collect::<Vec<String>>()
+                            .join(", ")
+                    )
+                }
+            },
         }
     }
 }
@@ -710,6 +773,26 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
             }
             Ok(format!("{}({})", fun.name, names.join(",")))
         }
+        Expr::GroupingSet(grouping_set) => match grouping_set {
+            GroupingSet::Rollup(exprs) => Ok(format!(
+                "ROLLUP ({})",
+                create_names(exprs.as_slice(), input_schema)?
+            )),
+            GroupingSet::Cube(exprs) => Ok(format!(
+                "CUBE ({})",
+                create_names(exprs.as_slice(), input_schema)?
+            )),
+            GroupingSet::GroupingSets(lists_of_exprs) => {
+                let mut list_of_names = vec![];
+                for exprs in lists_of_exprs {
+                    list_of_names.push(format!(
+                        "({})",
+                        create_names(exprs.as_slice(), input_schema)?
+                    ));
+                }
+                Ok(format!("GROUPING SETS ({})", list_of_names.join(", ")))
+            }
+        },
         Expr::InList {
             expr,
             list,
@@ -750,6 +833,15 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
     }
 }
 
+/// Create a comma separated list of names from a list of expressions
+fn create_names(exprs: &[Expr], input_schema: &DFSchema) -> Result<String> {
+    Ok(exprs
+        .iter()
+        .map(|e| create_name(e, input_schema))
+        .collect::<Result<Vec<String>>>()?
+        .join(", "))
+}
+
 #[cfg(test)]
 mod test {
     use crate::expr_fn::col;
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index b932eefa0b96..2433024e38a4 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -124,6 +124,10 @@ impl ExprSchemable for Expr {
                 "QualifiedWildcard expressions are not valid in a logical query plan"
                     .to_owned(),
             )),
+            Expr::GroupingSet(_) => {
+                // grouping sets do not really have a type and do not appear in projections
+                Ok(DataType::Null)
+            }
             Expr::GetIndexedField { ref expr, key } => {
                 let data_type = expr.get_type(schema)?;
 
@@ -198,6 +202,11 @@ impl ExprSchemable for Expr {
             Expr::GetIndexedField { ref expr, key } => {
                 let data_type = expr.get_type(input_schema)?;
                 get_indexed_field(&data_type, key).map(|x| x.is_nullable())
             }
+            Expr::GroupingSet(_) => {
+                // grouping sets do not really have the concept of nullable and do not appear
+                // in projections
+                Ok(true)
+            }
         }
     }
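
For reviewers' convenience, here is a small illustrative sketch (not part of the change set) of how the new `GroupingSet` variants added in `datafusion/expr/src/expr.rs` can be built from the expression API. It assumes the existing `col` helper from `datafusion_expr::expr_fn` and leans on the `Debug` implementation introduced above, so the exact rendered text may differ slightly.

```rust
// Illustrative sketch only; assumes the datafusion_expr crate as modified by this diff.
use datafusion_expr::expr::GroupingSet;
use datafusion_expr::expr_fn::col;
use datafusion_expr::Expr;

fn main() {
    // GROUP BY ROLLUP (state, age) expressed as a logical expression
    let rollup = Expr::GroupingSet(GroupingSet::Rollup(vec![col("state"), col("age")]));
    // GROUP BY CUBE (state, age)
    let cube = Expr::GroupingSet(GroupingSet::Cube(vec![col("state"), col("age")]));
    // GROUP BY GROUPING SETS ((state), (state, age))
    let sets = Expr::GroupingSet(GroupingSet::GroupingSets(vec![
        vec![col("state")],
        vec![col("state"), col("age")],
    ]));

    // With the Debug impl added in this change, these should render roughly as
    // `ROLLUP (#state, #age)`, `CUBE (#state, #age)` and
    // `GROUPING SETS ((#state), (#state, #age))`.
    println!("{:?}\n{:?}\n{:?}", rollup, cube, sets);
}
```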