diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs
index 425c685f9351..3a5e5b56dae8 100644
--- a/datafusion/core/src/sql/planner.rs
+++ b/datafusion/core/src/sql/planner.rs
@@ -1009,6 +1009,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             .map(|e| {
                 let group_by_expr =
                     self.sql_expr_to_logical_expr(e, &combined_schema, ctes)?;
+                // aliases from the projection can conflict with same-named expressions in the input
+                let mut alias_map = alias_map.clone();
+                for f in plan.schema().fields() {
+                    alias_map.remove(f.name());
+                }
                 let group_by_expr = resolve_aliases_to_exprs(&group_by_expr, &alias_map)?;
                 let group_by_expr =
                     resolve_positions_to_exprs(&group_by_expr, &select_exprs)
@@ -1023,7 +1028,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             .collect::<Result<Vec<Expr>>>()?;
 
         // process group by, aggregation or having
-        let (plan, select_exprs_post_aggr, having_expr_post_aggr_opt) =
+        let (plan, select_exprs_post_aggr, having_expr_post_aggr) =
             if !group_by_exprs.is_empty() || !aggr_exprs.is_empty() {
                 self.aggregate(
                     plan,
@@ -1051,7 +1056,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 (plan, select_exprs, having_expr_opt)
             };
 
-        let plan = if let Some(having_expr_post_aggr) = having_expr_post_aggr_opt {
+        let plan = if let Some(having_expr_post_aggr) = having_expr_post_aggr {
             LogicalPlanBuilder::from(plan)
                 .filter(having_expr_post_aggr)?
                 .build()?
@@ -1110,7 +1115,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         LogicalPlanBuilder::from(input).project(expr)?.build()
     }
 
-    /// Wrap a plan in an aggregate
+    /// Create an aggregate plan.
+    ///
+    /// An aggregate plan consists of grouping expressions, aggregate expressions, and an
+    /// optional HAVING expression (which is a filter on the output of the aggregate).
+    ///
+    /// # Arguments
+    ///
+    /// * `input` - The input plan that will be aggregated. The grouping, aggregate, and
+    ///   "having" expressions must all be resolvable from this plan.
+    /// * `select_exprs` - The projection expressions from the SELECT clause.
+    /// * `having_expr_opt` - Optional HAVING clause.
+    /// * `group_by_exprs` - Grouping expressions from the GROUP BY clause. These can be column
+    ///   references or more complex expressions.
+    /// * `aggr_exprs` - Aggregate expressions, such as `SUM(a)` or `COUNT(1)`.
+    ///
+    /// # Return
+    ///
+    /// The return value is a triplet of the following items:
+    ///
+    /// * `plan` - A [LogicalPlan::Aggregate] plan for the newly created aggregate.
+    /// * `select_exprs_post_aggr` - The projection expressions rewritten to reference columns from
+    ///   the aggregate
+    /// * `having_expr_post_aggr` - The "having" expression rewritten to reference a column from
+    ///   the aggregate
     fn aggregate(
         &self,
         input: LogicalPlan,
@@ -1151,7 +1179,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
 
         // Rewrite the HAVING expression to use the columns produced by the
         // aggregation.
-        let having_expr_post_aggr_opt = if let Some(having_expr) = having_expr_opt {
+        let having_expr_post_aggr = if let Some(having_expr) = having_expr_opt {
             let having_expr_post_aggr =
                 rebase_expr(having_expr, &aggr_projection_exprs, &input)?;
 
@@ -1166,7 +1194,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             None
         };
 
-        Ok((plan, select_exprs_post_aggr, having_expr_post_aggr_opt))
+        Ok((plan, select_exprs_post_aggr, having_expr_post_aggr))
     }
 
     /// Wrap a plan in a limit
diff --git a/datafusion/core/src/sql/utils.rs b/datafusion/core/src/sql/utils.rs
index fe7d0ee2cc5d..cd1fb316b76d 100644
--- a/datafusion/core/src/sql/utils.rs
+++ b/datafusion/core/src/sql/utils.rs
@@ -27,6 +27,7 @@ use crate::{
     error::{DataFusionError, Result},
     logical_plan::{Column, ExpressionVisitor, Recursion},
 };
+use datafusion_expr::expr::find_columns_referenced_by_expr;
 use std::collections::HashMap;
 
 /// Collect all deeply nested `Expr::AggregateFunction` and
@@ -58,9 +59,13 @@ pub(crate) fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
 }
 
 /// Collect all deeply nested `Expr::Column`'s. They are returned in order of
-/// appearance (depth first), with duplicates omitted.
+/// appearance (depth first), and may contain duplicates.
 pub(crate) fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
-    find_exprs_in_exprs(exprs, &|nested_expr| matches!(nested_expr, Expr::Column(_)))
+    exprs
+        .iter()
+        .flat_map(find_columns_referenced_by_expr)
+        .map(Expr::Column)
+        .collect()
 }
 
 /// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
@@ -137,8 +142,16 @@ where
 /// Convert any `Expr` to an `Expr::Column`.
 pub(crate) fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
     match expr {
-        Expr::Column(_) => Ok(expr.clone()),
-        _ => Ok(Expr::Column(Column::from_name(expr.name(plan.schema())?))),
+        Expr::Column(col) => {
+            let field = plan.schema().field_from_column(col)?;
+            Ok(Expr::Column(field.qualified_column()))
+        }
+        _ => {
+            // we should not be trying to create a name for the expression
+            // based on the input schema but this is the current behavior
+            // see https://github.com/apache/arrow-datafusion/issues/2456
+            Ok(Expr::Column(Column::from_name(expr.name(plan.schema())?)))
+        }
     }
 }
 
diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs
index a72b052376ad..41f2471f6c9e 100644
--- a/datafusion/core/tests/sql/group_by.rs
+++ b/datafusion/core/tests/sql/group_by.rs
@@ -232,6 +232,27 @@ async fn csv_query_group_by_avg() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_query_group_by_with_aliases() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let sql = "SELECT c1 AS c12, avg(c12) AS c1 FROM aggregate_test_100 GROUP BY c1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-----+---------------------+",
+        "| c12 | c1                  |",
+        "+-----+---------------------+",
+        "| a   | 0.48754517466109415 |",
+        "| b   | 0.41040709263815384 |",
+        "| c   | 0.6600456536439784  |",
+        "| d   | 0.48855379387549824 |",
+        "| e   | 0.48600669271341534 |",
+        "+-----+---------------------+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
 #[tokio::test]
 async fn csv_query_group_by_int_count() -> Result<()> {
     let ctx = SessionContext::new();
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 4d88ed815b14..7e1adac430b0 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -251,6 +251,77 @@ pub enum Expr {
     QualifiedWildcard { qualifier: String },
 }
 
+/// Recursively find all columns referenced by an expression
+pub fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
+    match e {
+        Expr::Alias(expr, _)
+        | Expr::Negative(expr)
+        | Expr::Cast { expr, .. }
+        | Expr::TryCast { expr, .. }
+        | Expr::Sort { expr, .. }
+        | Expr::InList { expr, .. }
+        | Expr::InSubquery { expr, .. }
+        | Expr::GetIndexedField { expr, .. }
+        | Expr::Not(expr)
+        | Expr::IsNotNull(expr)
+        | Expr::IsNull(expr) => find_columns_referenced_by_expr(expr),
+        Expr::Column(c) => vec![c.clone()],
+        Expr::BinaryExpr { left, right, .. } => {
+            let mut cols = vec![];
+            cols.extend(find_columns_referenced_by_expr(left.as_ref()));
+            cols.extend(find_columns_referenced_by_expr(right.as_ref()));
+            cols
+        }
+        Expr::Case {
+            expr,
+            when_then_expr,
+            else_expr,
+        } => {
+            let mut cols = vec![];
+            if let Some(expr) = expr {
+                cols.extend(find_columns_referenced_by_expr(expr.as_ref()));
+            }
+            for (w, t) in when_then_expr {
+                cols.extend(find_columns_referenced_by_expr(w.as_ref()));
+                cols.extend(find_columns_referenced_by_expr(t.as_ref()));
+            }
+            if let Some(else_expr) = else_expr {
+                cols.extend(find_columns_referenced_by_expr(else_expr.as_ref()));
+            }
+            cols
+        }
+        Expr::ScalarFunction { args, .. } => args
+            .iter()
+            .flat_map(find_columns_referenced_by_expr)
+            .collect(),
+        Expr::AggregateFunction { args, .. } => args
+            .iter()
+            .flat_map(find_columns_referenced_by_expr)
+            .collect(),
+        Expr::ScalarVariable(_, _)
+        | Expr::Exists { .. }
+        | Expr::Wildcard
+        | Expr::QualifiedWildcard { .. }
+        | Expr::ScalarSubquery(_)
+        | Expr::Literal(_) => vec![],
+        Expr::Between {
+            expr, low, high, ..
+        } => {
+            let mut cols = vec![];
+            cols.extend(find_columns_referenced_by_expr(expr.as_ref()));
+            cols.extend(find_columns_referenced_by_expr(low.as_ref()));
+            cols.extend(find_columns_referenced_by_expr(high.as_ref()));
+            cols
+        }
+        Expr::ScalarUDF { args, .. }
+        | Expr::WindowFunction { args, .. }
+        | Expr::AggregateUDF { args, .. } => args
+            .iter()
+            .flat_map(find_columns_referenced_by_expr)
+            .collect(),
+    }
+}
+
 /// Fixed seed for the hashing so that Ords are consistent across runs
 const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0);
 