MINOR: Partial fix for SQL aggregate queries with aliases (#2464)
andygrove authored May 6, 2022
1 parent 522ea52 commit 22464f0
Showing 4 changed files with 142 additions and 9 deletions.
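
For context, the shape of query this commit addresses (the same one exercised by the new regression test below) can be reproduced end to end through the public API roughly as follows. This is a sketch only: the CSV path is an assumption, and it presumes a table with a string column `c1` and a float column `c12` like the repository's `aggregate_test_100` fixture.

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Assumed location of the standard test fixture; any CSV with a string
    // column `c1` and a float column `c12` works the same way.
    ctx.register_csv(
        "aggregate_test_100",
        "testing/data/csv/aggregate_test_100.csv",
        CsvReadOptions::new(),
    )
    .await?;

    // The SELECT list aliases `c1` as `c12` and `avg(c12)` as `c1`; before
    // this fix the planner could resolve `GROUP BY c1` against the
    // conflicting alias instead of the input column `c1`.
    let df = ctx
        .sql("SELECT c1 AS c12, avg(c12) AS c1 FROM aggregate_test_100 GROUP BY c1")
        .await?;
    let results = df.collect().await?;
    datafusion::arrow::util::pretty::print_batches(&results)?;
    Ok(())
}
```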
38 changes: 33 additions & 5 deletions datafusion/core/src/sql/planner.rs
@@ -1006,6 +1006,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
.map(|e| {
let group_by_expr =
self.sql_expr_to_logical_expr(e, &combined_schema, ctes)?;
+ // aliases from the projection can conflict with same-named expressions in the input
+ let mut alias_map = alias_map.clone();
+ for f in plan.schema().fields() {
+     alias_map.remove(f.name());
+ }
let group_by_expr = resolve_aliases_to_exprs(&group_by_expr, &alias_map)?;
let group_by_expr =
resolve_positions_to_exprs(&group_by_expr, &select_exprs)
@@ -1020,7 +1025,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
.collect::<Result<Vec<Expr>>>()?;

// process group by, aggregation or having
- let (plan, select_exprs_post_aggr, having_expr_post_aggr_opt) =
+ let (plan, select_exprs_post_aggr, having_expr_post_aggr) =
if !group_by_exprs.is_empty() || !aggr_exprs.is_empty() {
self.aggregate(
plan,
@@ -1048,7 +1053,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
(plan, select_exprs, having_expr_opt)
};

- let plan = if let Some(having_expr_post_aggr) = having_expr_post_aggr_opt {
+ let plan = if let Some(having_expr_post_aggr) = having_expr_post_aggr {
LogicalPlanBuilder::from(plan)
.filter(having_expr_post_aggr)?
.build()?
@@ -1107,7 +1112,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
LogicalPlanBuilder::from(input).project(expr)?.build()
}

- /// Wrap a plan in an aggregate
+ /// Create an aggregate plan.
+ ///
+ /// An aggregate plan consists of grouping expressions, aggregate expressions, and an
+ /// optional HAVING expression (which is a filter on the output of the aggregate).
+ ///
+ /// # Arguments
+ ///
+ /// * `input` - The input plan that will be aggregated. The grouping, aggregate, and
+ /// "having" expressions must all be resolvable from this plan.
+ /// * `select_exprs` - The projection expressions from the SELECT clause.
+ /// * `having_expr_opt` - Optional HAVING clause.
+ /// * `group_by_exprs` - Grouping expressions from the GROUP BY clause. These can be column
+ /// references or more complex expressions.
+ /// * `aggr_exprs` - Aggregate expressions, such as `SUM(a)` or `COUNT(1)`.
+ ///
+ /// # Return
+ ///
+ /// The return value is a triplet of the following items:
+ ///
+ /// * `plan` - A [LogicalPlan::Aggregate] plan for the newly created aggregate.
+ /// * `select_exprs_post_aggr` - The projection expressions rewritten to reference columns from
+ /// the aggregate
+ /// * `having_expr_post_aggr` - The "having" expression rewritten to reference a column from
+ /// the aggregate
fn aggregate(
&self,
input: LogicalPlan,
@@ -1148,7 +1176,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

// Rewrite the HAVING expression to use the columns produced by the
// aggregation.
- let having_expr_post_aggr_opt = if let Some(having_expr) = having_expr_opt {
+ let having_expr_post_aggr = if let Some(having_expr) = having_expr_opt {
let having_expr_post_aggr =
rebase_expr(having_expr, &aggr_projection_exprs, &input)?;

@@ -1163,7 +1191,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
None
};

- Ok((plan, select_exprs_post_aggr, having_expr_post_aggr_opt))
+ Ok((plan, select_exprs_post_aggr, having_expr_post_aggr))
}

/// Wrap a plan in a limit
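
The core planner change above removes, from the alias map, any alias that collides with a column of the input plan before GROUP BY aliases are resolved, so that `GROUP BY c1` keeps meaning the input column `c1` even when the SELECT list defines `avg(c12) AS c1`. A minimal standalone sketch of that pruning idea, using plain `String`s as stand-ins for DataFusion's `Expr` values (the helper name `prune_conflicting_aliases` is illustrative, not part of the codebase):

```rust
use std::collections::HashMap;

/// Toy model of the planner change: `alias_map` maps SELECT-list aliases to
/// the (stringified) expressions they stand for, and `input_fields` are the
/// column names of the aggregate's input plan.
fn prune_conflicting_aliases(
    mut alias_map: HashMap<String, String>,
    input_fields: &[&str],
) -> HashMap<String, String> {
    // An alias that shadows an input column must not be substituted into a
    // GROUP BY expression: `GROUP BY c1` should keep referring to the input
    // column `c1`, not to whatever the projection aliased as `c1`.
    for f in input_fields {
        alias_map.remove(*f);
    }
    alias_map
}

fn main() {
    // SELECT c1 AS c12, avg(c12) AS c1 ... GROUP BY c1
    let alias_map = HashMap::from([
        ("c12".to_string(), "c1".to_string()),
        ("c1".to_string(), "AVG(c12)".to_string()),
    ]);
    let pruned = prune_conflicting_aliases(alias_map, &["c1", "c12"]);
    // Both aliases collide with input columns, so nothing is substituted and
    // GROUP BY c1 resolves to the input column rather than AVG(c12).
    assert!(pruned.is_empty());
}
```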
21 changes: 17 additions & 4 deletions datafusion/core/src/sql/utils.rs
@@ -27,6 +27,7 @@ use crate::{
error::{DataFusionError, Result},
logical_plan::{Column, ExpressionVisitor, Recursion},
};
+ use datafusion_expr::expr::find_columns_referenced_by_expr;
use std::collections::HashMap;

/// Collect all deeply nested `Expr::AggregateFunction` and
@@ -58,9 +59,13 @@ pub(crate) fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
}

/// Collect all deeply nested `Expr::Column`'s. They are returned in order of
- /// appearance (depth first), with duplicates omitted.
+ /// appearance (depth first), and may contain duplicates.
pub(crate) fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
- find_exprs_in_exprs(exprs, &|nested_expr| matches!(nested_expr, Expr::Column(_)))
+ exprs
+     .iter()
+     .flat_map(find_columns_referenced_by_expr)
+     .map(Expr::Column)
+     .collect()
}

/// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
@@ -137,8 +142,16 @@
/// Convert any `Expr` to an `Expr::Column`.
pub(crate) fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
match expr {
- Expr::Column(_) => Ok(expr.clone()),
- _ => Ok(Expr::Column(Column::from_name(expr.name(plan.schema())?))),
+ Expr::Column(col) => {
+     let field = plan.schema().field_from_column(col)?;
+     Ok(Expr::Column(field.qualified_column()))
+ }
+ _ => {
+     // we should not be trying to create a name for the expression
+     // based on the input schema but this is the current behavior
+     // see https://github.com/apache/arrow-datafusion/issues/2456
+     Ok(Expr::Column(Column::from_name(expr.name(plan.schema())?)))
+ }
}
}

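
The `expr_as_column_expr` change above stops flattening an existing column reference to its bare name and instead re-resolves it against the plan's schema so the qualifier is preserved. A rough standalone illustration of that lookup, with simplified stand-ins for DataFusion's `Column` and `DFField` types (all names here are illustrative):

```rust
// Simplified stand-ins for DataFusion's `Column` and `DFField` types.
#[derive(Debug, Clone, PartialEq)]
struct Column {
    relation: Option<String>,
    name: String,
}

#[derive(Debug)]
struct Field {
    qualifier: Option<String>,
    name: String,
}

impl Field {
    // Mirrors the role of `qualified_column` in the diff above: emit a
    // column reference that carries the field's qualifier.
    fn qualified_column(&self) -> Column {
        Column {
            relation: self.qualifier.clone(),
            name: self.name.clone(),
        }
    }
}

// Resolve a (possibly unqualified) column against the schema fields,
// analogous to `plan.schema().field_from_column(col)` followed by
// `field.qualified_column()`.
fn qualify(col: &Column, fields: &[Field]) -> Option<Column> {
    fields
        .iter()
        .find(|f| {
            f.name == col.name
                && (col.relation.is_none() || col.relation == f.qualifier)
        })
        .map(Field::qualified_column)
}

fn main() {
    let fields = vec![Field {
        qualifier: Some("aggregate_test_100".to_string()),
        name: "c1".to_string(),
    }];
    let unqualified = Column {
        relation: None,
        name: "c1".to_string(),
    };
    // Prints the fully qualified reference `aggregate_test_100.c1`.
    println!("{:?}", qualify(&unqualified, &fields));
}
```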
21 changes: 21 additions & 0 deletions datafusion/core/tests/sql/group_by.rs
@@ -232,6 +232,27 @@ async fn csv_query_group_by_avg() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn csv_query_group_by_with_aliases() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT c1 AS c12, avg(c12) AS c1 FROM aggregate_test_100 GROUP BY c1";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----+---------------------+",
"| c12 | c1 |",
"+-----+---------------------+",
"| a | 0.48754517466109415 |",
"| b | 0.41040709263815384 |",
"| c | 0.6600456536439784 |",
"| d | 0.48855379387549824 |",
"| e | 0.48600669271341534 |",
"+-----+---------------------+",
];
assert_batches_sorted_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn csv_query_group_by_int_count() -> Result<()> {
let ctx = SessionContext::new();
71 changes: 71 additions & 0 deletions datafusion/expr/src/expr.rs
@@ -251,6 +251,77 @@ pub enum Expr {
QualifiedWildcard { qualifier: String },
}

/// Recursively find all columns referenced by an expression
pub fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
match e {
Expr::Alias(expr, _)
| Expr::Negative(expr)
| Expr::Cast { expr, .. }
| Expr::TryCast { expr, .. }
| Expr::Sort { expr, .. }
| Expr::InList { expr, .. }
| Expr::InSubquery { expr, .. }
| Expr::GetIndexedField { expr, .. }
| Expr::Not(expr)
| Expr::IsNotNull(expr)
| Expr::IsNull(expr) => find_columns_referenced_by_expr(expr),
Expr::Column(c) => vec![c.clone()],
Expr::BinaryExpr { left, right, .. } => {
let mut cols = vec![];
cols.extend(find_columns_referenced_by_expr(left.as_ref()));
cols.extend(find_columns_referenced_by_expr(right.as_ref()));
cols
}
Expr::Case {
expr,
when_then_expr,
else_expr,
} => {
let mut cols = vec![];
if let Some(expr) = expr {
cols.extend(find_columns_referenced_by_expr(expr.as_ref()));
}
for (w, t) in when_then_expr {
cols.extend(find_columns_referenced_by_expr(w.as_ref()));
cols.extend(find_columns_referenced_by_expr(t.as_ref()));
}
if let Some(else_expr) = else_expr {
cols.extend(find_columns_referenced_by_expr(else_expr.as_ref()));
}
cols
}
Expr::ScalarFunction { args, .. } => args
.iter()
.flat_map(find_columns_referenced_by_expr)
.collect(),
Expr::AggregateFunction { args, .. } => args
.iter()
.flat_map(find_columns_referenced_by_expr)
.collect(),
Expr::ScalarVariable(_, _)
| Expr::Exists { .. }
| Expr::Wildcard
| Expr::QualifiedWildcard { .. }
| Expr::ScalarSubquery(_)
| Expr::Literal(_) => vec![],
Expr::Between {
expr, low, high, ..
} => {
let mut cols = vec![];
cols.extend(find_columns_referenced_by_expr(expr.as_ref()));
cols.extend(find_columns_referenced_by_expr(low.as_ref()));
cols.extend(find_columns_referenced_by_expr(high.as_ref()));
cols
}
Expr::ScalarUDF { args, .. }
| Expr::WindowFunction { args, .. }
| Expr::AggregateUDF { args, .. } => args
.iter()
.flat_map(find_columns_referenced_by_expr)
.collect(),
}
}

/// Fixed seed for the hashing so that Ords are consistent across runs
const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0);

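
For completeness, a small usage sketch of the new helper added in `datafusion/expr/src/expr.rs`. The exact import paths (`datafusion_expr`, `datafusion_common::Column`, `Operator`) are assumptions based on the crate layout at this commit.

```rust
// Collect the columns referenced by `(a + b) AS total`.
// Import paths are assumptions; `find_columns_referenced_by_expr` is the
// function introduced in this commit.
use datafusion_common::Column;
use datafusion_expr::expr::find_columns_referenced_by_expr;
use datafusion_expr::{Expr, Operator};

fn main() {
    let expr = Expr::Alias(
        Box::new(Expr::BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("a"))),
            op: Operator::Plus,
            right: Box::new(Expr::Column(Column::from_name("b"))),
        }),
        "total".to_string(),
    );
    // Aliases and binary expressions are traversed recursively, so both
    // `a` and `b` are reported.
    let cols = find_columns_referenced_by_expr(&expr);
    assert_eq!(cols.len(), 2);
    println!("{:?}", cols);
}
```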