Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MINOR: Partial fix for SQL aggregate queries with aliases #2464

Merged
merged 6 commits into the base branch from the contributor's branch
May 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions datafusion/core/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
.map(|e| {
let group_by_expr =
self.sql_expr_to_logical_expr(e, &combined_schema, ctes)?;
// aliases from the projection can conflict with same-named expressions in the input
let mut alias_map = alias_map.clone();
for f in plan.schema().fields() {
alias_map.remove(f.name());
}
Comment on lines +1013 to +1016
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the first part of the fix

let group_by_expr = resolve_aliases_to_exprs(&group_by_expr, &alias_map)?;
let group_by_expr =
resolve_positions_to_exprs(&group_by_expr, &select_exprs)
Expand All @@ -1023,7 +1028,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
.collect::<Result<Vec<Expr>>>()?;

// process group by, aggregation or having
let (plan, select_exprs_post_aggr, having_expr_post_aggr_opt) =
let (plan, select_exprs_post_aggr, having_expr_post_aggr) =
if !group_by_exprs.is_empty() || !aggr_exprs.is_empty() {
self.aggregate(
plan,
Expand Down Expand Up @@ -1051,7 +1056,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
(plan, select_exprs, having_expr_opt)
};

let plan = if let Some(having_expr_post_aggr) = having_expr_post_aggr_opt {
let plan = if let Some(having_expr_post_aggr) = having_expr_post_aggr {
LogicalPlanBuilder::from(plan)
.filter(having_expr_post_aggr)?
.build()?
Expand Down Expand Up @@ -1110,7 +1115,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
LogicalPlanBuilder::from(input).project(expr)?.build()
}

/// Wrap a plan in an aggregate
/// Create an aggregate plan.
///
/// An aggregate plan consists of grouping expressions, aggregate expressions, and an
/// optional HAVING expression (which is a filter on the output of the aggregate).
///
/// # Arguments
///
/// * `input` - The input plan that will be aggregated. The grouping, aggregate, and
/// "having" expressions must all be resolvable from this plan.
/// * `select_exprs` - The projection expressions from the SELECT clause.
/// * `having_expr_opt` - Optional HAVING clause.
/// * `group_by_exprs` - Grouping expressions from the GROUP BY clause. These can be column
/// references or more complex expressions.
/// * `aggr_exprs` - Aggregate expressions, such as `SUM(a)` or `COUNT(1)`.
///
/// # Return
///
/// The return value is a triplet of the following items:
///
/// * `plan` - A [LogicalPlan::Aggregate] plan for the newly created aggregate.
/// * `select_exprs_post_aggr` - The projection expressions rewritten to reference columns from
/// the aggregate
/// * `having_expr_post_aggr` - The "having" expression rewritten to reference a column from
/// the aggregate
fn aggregate(
&self,
input: LogicalPlan,
Expand Down Expand Up @@ -1151,7 +1179,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

// Rewrite the HAVING expression to use the columns produced by the
// aggregation.
let having_expr_post_aggr_opt = if let Some(having_expr) = having_expr_opt {
let having_expr_post_aggr = if let Some(having_expr) = having_expr_opt {
let having_expr_post_aggr =
rebase_expr(having_expr, &aggr_projection_exprs, &input)?;

Expand All @@ -1166,7 +1194,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
None
};

Ok((plan, select_exprs_post_aggr, having_expr_post_aggr_opt))
Ok((plan, select_exprs_post_aggr, having_expr_post_aggr))
}

/// Wrap a plan in a limit
Expand Down
21 changes: 17 additions & 4 deletions datafusion/core/src/sql/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use crate::{
error::{DataFusionError, Result},
logical_plan::{Column, ExpressionVisitor, Recursion},
};
use datafusion_expr::expr::find_columns_referenced_by_expr;
use std::collections::HashMap;

/// Collect all deeply nested `Expr::AggregateFunction` and
Expand Down Expand Up @@ -58,9 +59,13 @@ pub(crate) fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
}

/// Collect all deeply nested `Expr::Column`'s. They are returned in order of
/// appearance (depth first), with duplicates omitted.
/// appearance (depth first), and may contain duplicates.
pub(crate) fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
find_exprs_in_exprs(exprs, &|nested_expr| matches!(nested_expr, Expr::Column(_)))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code was only finding some column expressions; it did not recurse into nested expressions, so it missed columns referenced deeper in the expression tree.

exprs
.iter()
.flat_map(find_columns_referenced_by_expr)
.map(Expr::Column)
.collect()
}

/// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
Expand Down Expand Up @@ -137,8 +142,16 @@ where
/// Convert any `Expr` to an `Expr::Column`.
pub(crate) fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
match expr {
Expr::Column(_) => Ok(expr.clone()),
_ => Ok(Expr::Column(Column::from_name(expr.name(plan.schema())?))),
Expr::Column(col) => {
let field = plan.schema().field_from_column(col)?;
Ok(Expr::Column(field.qualified_column()))
}
_ => {
// we should not be trying to create a name for the expression
// based on the input schema but this is the current behavior
// see https://github.com/apache/arrow-datafusion/issues/2456
Ok(Expr::Column(Column::from_name(expr.name(plan.schema())?)))
}
}
}

Expand Down
21 changes: 21 additions & 0 deletions datafusion/core/tests/sql/group_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,27 @@ async fn csv_query_group_by_avg() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn csv_query_group_by_with_aliases() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT c1 AS c12, avg(c12) AS c1 FROM aggregate_test_100 GROUP BY c1";
Copy link
Member Author

@andygrove andygrove May 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This originally failed with: `Aggregations require unique expression names but the expression "AVG(#aggregate_test_100.c12)" at position 0 and "AVG(#aggregate_test_100.c12)" at position 1 have the same name. Consider aliasing ("AS") one of them`

let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----+---------------------+",
"| c12 | c1 |",
"+-----+---------------------+",
"| a | 0.48754517466109415 |",
"| b | 0.41040709263815384 |",
"| c | 0.6600456536439784 |",
"| d | 0.48855379387549824 |",
"| e | 0.48600669271341534 |",
"+-----+---------------------+",
];
assert_batches_sorted_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn csv_query_group_by_int_count() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
71 changes: 71 additions & 0 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,77 @@ pub enum Expr {
QualifiedWildcard { qualifier: String },
}

/// Recursively find all columns referenced by an expression
pub fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this code could use an ExprVisitor to avoid having to do the recursion itself (and potentially missing some case)

I took a crack at doing so #2471

match e {
Expr::Alias(expr, _)
| Expr::Negative(expr)
| Expr::Cast { expr, .. }
| Expr::TryCast { expr, .. }
| Expr::Sort { expr, .. }
| Expr::InList { expr, .. }
| Expr::InSubquery { expr, .. }
| Expr::GetIndexedField { expr, .. }
| Expr::Not(expr)
| Expr::IsNotNull(expr)
| Expr::IsNull(expr) => find_columns_referenced_by_expr(expr),
Expr::Column(c) => vec![c.clone()],
Expr::BinaryExpr { left, right, .. } => {
let mut cols = vec![];
cols.extend(find_columns_referenced_by_expr(left.as_ref()));
cols.extend(find_columns_referenced_by_expr(right.as_ref()));
cols
}
Expr::Case {
expr,
when_then_expr,
else_expr,
} => {
let mut cols = vec![];
if let Some(expr) = expr {
cols.extend(find_columns_referenced_by_expr(expr.as_ref()));
}
for (w, t) in when_then_expr {
cols.extend(find_columns_referenced_by_expr(w.as_ref()));
cols.extend(find_columns_referenced_by_expr(t.as_ref()));
}
if let Some(else_expr) = else_expr {
cols.extend(find_columns_referenced_by_expr(else_expr.as_ref()));
}
cols
}
Expr::ScalarFunction { args, .. } => args
.iter()
.flat_map(find_columns_referenced_by_expr)
.collect(),
Expr::AggregateFunction { args, .. } => args
.iter()
.flat_map(find_columns_referenced_by_expr)
.collect(),
Expr::ScalarVariable(_, _)
| Expr::Exists { .. }
| Expr::Wildcard
| Expr::QualifiedWildcard { .. }
| Expr::ScalarSubquery(_)
| Expr::Literal(_) => vec![],
Expr::Between {
expr, low, high, ..
} => {
let mut cols = vec![];
cols.extend(find_columns_referenced_by_expr(expr.as_ref()));
cols.extend(find_columns_referenced_by_expr(low.as_ref()));
cols.extend(find_columns_referenced_by_expr(high.as_ref()));
cols
}
Expr::ScalarUDF { args, .. }
| Expr::WindowFunction { args, .. }
| Expr::AggregateUDF { args, .. } => args
.iter()
.flat_map(find_columns_referenced_by_expr)
.collect(),
}
}

/// Fixed seed for the hashing so that Ords are consistent across runs
const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0);

Expand Down