Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SQL planner support for grouping() aggregate expressions #2486

Merged
merged 4 commits into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 51 additions & 22 deletions datafusion/core/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::sql::utils::find_columns_referenced_by_expr;
use arrow::datatypes::DataType;
pub use datafusion_common::{Column, ExprSchema};
pub use datafusion_expr::expr_fn::*;
use datafusion_expr::logical_plan::Aggregate;
use datafusion_expr::BuiltinScalarFunction;
pub use datafusion_expr::Expr;
use datafusion_expr::StateTypeFunction;
Expand Down Expand Up @@ -136,35 +137,63 @@ pub fn create_udaf(
)
}

/// Gather every column mentioned by an aggregate node.
///
/// Walks the aggregate expressions first and the group-by expressions second,
/// collecting the columns each one references (duplicates are kept).
fn agg_cols(agg: &Aggregate) -> Result<Vec<Column>> {
    let referenced: Vec<Column> = agg
        .aggr_expr
        .iter()
        .chain(agg.group_expr.iter())
        .flat_map(find_columns_referenced_by_expr)
        .collect();
    Ok(referenced)
}

/// Build output fields for `exprs` when the plan is (or sits on top of) an aggregate.
///
/// A bare column that is referenced by the aggregate's own expressions is
/// resolved against the schema of the aggregate's *input*; every other
/// expression is resolved against the schema of `plan`. The first resolution
/// error aborts the whole conversion.
fn exprlist_to_fields_aggregate(
    exprs: &[Expr],
    plan: &LogicalPlan,
    agg: &Aggregate,
) -> Result<Vec<DFField>> {
    let referenced = agg_cols(agg)?;
    exprs
        .iter()
        .map(|expr| {
            let is_aggregate_input_column =
                matches!(expr, Expr::Column(col) if referenced.contains(col));
            if is_aggregate_input_column {
                // resolve against schema of input to aggregate
                expr.to_field(agg.input.schema())
            } else {
                expr.to_field(plan.schema())
            }
        })
        .collect()
}

/// Create field meta-data from an expression, for use in a result set schema
pub fn exprlist_to_fields<'a>(
expr: impl IntoIterator<Item = &'a Expr>,
plan: &LogicalPlan,
) -> Result<Vec<DFField>> {
match plan {
let exprs: Vec<Expr> = expr.into_iter().cloned().collect();
// When dealing with aggregate plans we cannot simply look in the aggregate output schema,
// because it will contain columns representing complex expressions (such as a column named
// `#GROUPING(person.state)`), so in order to resolve `person.state` in this case we need to
// look at the input to the aggregate instead.
let fields = match plan {
LogicalPlan::Aggregate(agg) => {
let group_expr: Vec<Column> = agg
.group_expr
.iter()
.flat_map(find_columns_referenced_by_expr)
.collect();
let exprs: Vec<Expr> = expr.into_iter().cloned().collect();
let mut fields = vec![];
for expr in &exprs {
match expr {
Expr::Column(c) if group_expr.iter().any(|x| x == c) => {
// resolve against schema of input to aggregate
fields.push(expr.to_field(agg.input.schema())?);
}
_ => fields.push(expr.to_field(plan.schema())?),
}
}
Ok(fields)
}
_ => {
let input_schema = &plan.schema();
expr.into_iter().map(|e| e.to_field(input_schema)).collect()
Some(exprlist_to_fields_aggregate(&exprs, plan, agg))
}
LogicalPlan::Window(window) => match window.input.as_ref() {
LogicalPlan::Aggregate(agg) => {
Some(exprlist_to_fields_aggregate(&exprs, plan, agg))
}
_ => None,
},
_ => None,
};
if let Some(fields) = fields {
fields
} else {
// look for exact match in plan's output schema
let input_schema = &plan.schema();
exprs.iter().map(|e| e.to_field(input_schema)).collect()
}
}

Expand Down
32 changes: 32 additions & 0 deletions datafusion/core/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4685,6 +4685,38 @@ mod tests {
quick_test(sql, expected);
}

// Planner test: grouping() calls in the select list combined with
// `GROUP BY id, ROLLUP (state, age)`.
// The expected plan text lists GROUPING(#person.state) and GROUPING(#person.age)
// once each in the Aggregate node's `aggr` list, while the Projection refers to
// them by their output column names (including inside the `+` expression).
#[tokio::test]
async fn aggregate_with_rollup_with_grouping() {
let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), COUNT(*) \
FROM person GROUP BY id, ROLLUP (state, age)";
let expected = "Projection: #person.id, #person.state, #person.age, #GROUPING(person.state), #GROUPING(person.age), #GROUPING(person.state) + #GROUPING(person.age), #COUNT(UInt8(1))\
\n Aggregate: groupBy=[[#person.id, ROLLUP (#person.state, #person.age)]], aggr=[[GROUPING(#person.state), GROUPING(#person.age), COUNT(UInt8(1))]]\
\n TableScan: person projection=None";
// NOTE(review): quick_test is defined outside this view; presumably it plans
// `sql` and compares the rendered logical plan with `expected` — confirm.
quick_test(sql, expected);
}

// Planner test: grouping() used inside a window function's PARTITION BY
// (both in an arithmetic expression and inside a CASE) on top of
// `GROUP BY ROLLUP(state, last_name)`.
// The expected plan text has a WindowAggr node directly above the Aggregate
// node, with the GROUPING(...) calls evaluated in the Aggregate and referenced
// by output column name in the window and projection expressions.
#[tokio::test]
async fn rank_partition_grouping() {
let sql = "select
sum(age) as total_sum,
state,
last_name,
grouping(state) + grouping(last_name) as x,
rank() over (
partition by grouping(state) + grouping(last_name),
case when grouping(last_name) = 0 then state end
order by sum(age) desc
) as the_rank
from
person
group by rollup(state, last_name)";
let expected = "Projection: #SUM(person.age) AS total_sum, #person.state, #person.last_name, #GROUPING(person.state) + #GROUPING(person.last_name) AS x, #RANK() PARTITION BY [#GROUPING(person.state) + #GROUPING(person.last_name), CASE WHEN #GROUPING(person.last_name) = Int64(0) THEN #person.state END] ORDER BY [#SUM(person.age) DESC NULLS FIRST] AS the_rank\
\n WindowAggr: windowExpr=[[RANK() PARTITION BY [#GROUPING(person.state) + #GROUPING(person.last_name), CASE WHEN #GROUPING(person.last_name) = Int64(0) THEN #person.state END] ORDER BY [#SUM(person.age) DESC NULLS FIRST]]]\
\n Aggregate: groupBy=[[ROLLUP (#person.state, #person.last_name)]], aggr=[[SUM(#person.age), GROUPING(#person.state), GROUPING(#person.last_name)]]\
\n TableScan: person projection=None";
// NOTE(review): quick_test is defined outside this view; presumably it plans
// `sql` and compares the rendered logical plan with `expected` — confirm.
quick_test(sql, expected);
}

#[tokio::test]
async fn aggregate_with_cube() {
let sql =
Expand Down
6 changes: 6 additions & 0 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ pub enum AggregateFunction {
ApproxPercentileContWithWeight,
/// ApproxMedian
ApproxMedian,
/// Grouping
Grouping,
}

impl fmt::Display for AggregateFunction {
Expand Down Expand Up @@ -121,6 +123,7 @@ impl FromStr for AggregateFunction {
AggregateFunction::ApproxPercentileContWithWeight
}
"approx_median" => AggregateFunction::ApproxMedian,
"grouping" => AggregateFunction::Grouping,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand Down Expand Up @@ -173,6 +176,7 @@ pub fn return_type(
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
AggregateFunction::Grouping => Ok(DataType::Int32),
}
}

Expand Down Expand Up @@ -326,6 +330,7 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
}
}

Expand All @@ -335,6 +340,7 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
match fun {
AggregateFunction::Count
| AggregateFunction::ApproxDistinct
| AggregateFunction::Grouping
| AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
AggregateFunction::Min | AggregateFunction::Max => {
let valid = STRINGS
Expand Down
5 changes: 5 additions & 0 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ pub fn create_aggregate_expr(
name,
return_type,
)),
(AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new(
coerced_phy_exprs[0].clone(),
name,
return_type,
)),
(AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new(
coerced_phy_exprs[0].clone(),
name,
Expand Down
93 changes: 93 additions & 0 deletions datafusion/physical-expr/src/aggregate/grouping.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines physical expressions that can be evaluated at runtime during query execution

use std::any::Any;
use std::sync::Arc;

use crate::{AggregateExpr, PhysicalExpr};
use arrow::datatypes::DataType;
use arrow::datatypes::Field;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;

use crate::expressions::format_state_name;

/// GROUPING aggregate expression.
///
/// Only the expression's metadata (name, type, nullability, argument) is
/// modelled here; `create_accumulator` for this type returns a
/// `NotImplemented` error, so the function cannot be executed yet.
// NOTE(review): the previous description ("Returns the amount of non-null
// values of the given expression") reads like COUNT's documentation and does
// not describe GROUPING — confirm the intended result semantics before
// documenting them here.
#[derive(Debug)]
pub struct Grouping {
/// Name of the output field.
name: String,
/// Data type reported for the output field and the state field.
data_type: DataType,
/// Whether the output field is nullable (set to `true` by `new`).
nullable: bool,
/// The argument expression passed to GROUPING.
expr: Arc<dyn PhysicalExpr>,
}

impl Grouping {
    /// Build a GROUPING aggregate over `expr`.
    ///
    /// `name` becomes the output field name and `data_type` its data type.
    /// The output field is always marked nullable.
    pub fn new(
        expr: Arc<dyn PhysicalExpr>,
        name: impl Into<String>,
        data_type: DataType,
    ) -> Self {
        let name = name.into();
        Self {
            name,
            data_type,
            nullable: true,
            expr,
        }
    }
}

impl AggregateExpr for Grouping {
/// Expose this expression as `&dyn Any` so that code holding it behind a
/// trait object can downcast back to the concrete `Grouping` type.
fn as_any(&self) -> &dyn Any {
self
}

/// Describe the output column of this aggregate: the configured name,
/// data type and nullability.
fn field(&self) -> Result<Field> {
    let data_type = self.data_type.clone();
    let output = Field::new(&self.name, data_type, self.nullable);
    Ok(output)
}

/// Describe the intermediate state: a single nullable field whose name is
/// derived from the aggregate name with the "grouping" suffix and whose
/// type matches the output data type.
fn state_fields(&self) -> Result<Vec<Field>> {
    let state_name = format_state_name(&self.name, "grouping");
    let state = Field::new(&state_name, self.data_type.clone(), true);
    Ok(vec![state])
}

/// The input expressions of this aggregate: just the single GROUPING argument.
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
    // Cloning the `Arc` only bumps the reference count.
    let argument = Arc::clone(&self.expr);
    vec![argument]
}

/// Accumulator creation is not supported yet: this always returns
/// `DataFusionError::NotImplemented`, so a GROUPING expression can be
/// constructed and described but not executed.
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Err(DataFusionError::NotImplemented(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍
I see -- I was going to suggest adding a test to sql_integ but it isn't ready yet 👍

"physical plan is not yet implemented for GROUPING aggregate function"
.to_owned(),
))
}

/// The name given to this aggregate expression at construction time.
fn name(&self) -> &str {
    self.name.as_str()
}
}
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub(crate) mod correlation;
pub(crate) mod count;
pub(crate) mod count_distinct;
pub(crate) mod covariance;
pub(crate) mod grouping;
#[macro_use]
pub(crate) mod min_max;
pub mod build_in;
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ pub use crate::aggregate::correlation::Correlation;
pub use crate::aggregate::count::Count;
pub use crate::aggregate::count_distinct::DistinctCount;
pub use crate::aggregate::covariance::{Covariance, CovariancePop};
pub use crate::aggregate::grouping::Grouping;
pub use crate::aggregate::min_max::{Max, Min};
pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator};
pub use crate::aggregate::stats::StatsType;
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ enum AggregateFunction {
APPROX_PERCENTILE_CONT = 14;
APPROX_MEDIAN=15;
APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
GROUPING = 17;
}

message AggregateExprNode {
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
Self::ApproxPercentileContWithWeight
}
protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian,
protobuf::AggregateFunction::Grouping => Self::Grouping,
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions datafusion/proto/src/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
Self::ApproxPercentileContWithWeight
}
AggregateFunction::ApproxMedian => Self::ApproxMedian,
AggregateFunction::Grouping => Self::Grouping,
}
}
}
Expand Down Expand Up @@ -541,6 +542,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
AggregateFunction::ApproxMedian => {
protobuf::AggregateFunction::ApproxMedian
}
AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping,
};

let aggregate_expr = protobuf::AggregateExprNode {
Expand Down