From a82ad9eb281504f2b7fb4fc5df80760ac6da800b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 5 Oct 2022 12:56:23 -0400 Subject: [PATCH] Consolidate and better tests for expression re-rewriting / aliasing --- datafusion/expr/src/expr.rs | 4 +- datafusion/optimizer/src/type_coercion.rs | 33 ++--- .../src/unwrap_cast_in_comparison.rs | 40 +----- datafusion/optimizer/src/utils.rs | 122 +++++++++++++++++- 4 files changed, 136 insertions(+), 63 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 008a2c454d83..4e90971afa39 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -463,8 +463,8 @@ impl Expr { } /// Return `self AS name` alias expression - pub fn alias(self, name: &str) -> Expr { - Expr::Alias(Box::new(self), name.to_owned()) + pub fn alias(self, name: impl Into) -> Expr { + Expr::Alias(Box::new(self), name.into()) } /// Return `self IN ` if `negated` is false, otherwise diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 2073713ddd31..6e0291d22103 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -17,11 +17,12 @@ //! Optimizer rule for type validation and coercion +use crate::utils::rewrite_preserving_name; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::binary_rule::{coerce_types, comparison_coercion}; -use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; +use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::data_types; use datafusion_expr::utils::from_plan; @@ -87,30 +88,13 @@ fn optimize_internal( schema: Arc::new(schema), }; - let original_expr_names: Vec> = plan - .expressions() - .iter() - .map(|expr| expr.name().ok()) - .collect(); - let new_expr = plan .expressions() .into_iter() - .zip(original_expr_names) - .map(|(expr, original_name)| { - let expr = expr.rewrite(&mut expr_rewrite)?; - + .map(|expr| { // ensure aggregate names don't change: // https://github.com/apache/arrow-datafusion/issues/3555 - if matches!(expr, Expr::AggregateFunction { .. }) { - if let Some((alias, name)) = original_name.zip(expr.name().ok()) { - if alias != name { - return Ok(expr.alias(&alias)); - } - } - } - - Ok(expr) + rewrite_preserving_name(expr, &mut expr_rewrite) }) .collect::>>()?; @@ -637,7 +621,8 @@ mod test { let mut config = OptimizerConfig::default(); let plan = rule.optimize(&plan, &mut config)?; assert_eq!( - "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n EmptyRelation", + "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\ + \n EmptyRelation", &format!("{:?}", plan) ); // a in (1,4,8), a is decimal @@ -655,7 +640,8 @@ mod test { let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?); let plan = rule.optimize(&plan, &mut config)?; assert_eq!( - "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n EmptyRelation", + "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\ + \n EmptyRelation", &format!("{:?}", plan) ); Ok(()) @@ -753,7 +739,8 @@ mod test { let mut config = OptimizerConfig::default(); let plan = rule.optimize(&plan, &mut config).unwrap(); assert_eq!( - "Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation", + "Projection: a LIKE CAST(NULL AS Utf8) AS a LIKE NULL \ + \n EmptyRelation", &format!("{:?}", plan) ); diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 542c29bd7767..b9135768014d 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -18,12 +18,13 @@ //! Unwrap-cast binary comparison rule can be used to the binary/inlist comparison expr now, and other type //! of expr can be added if needed. //! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. +use crate::utils::rewrite_preserving_name; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; +use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; use datafusion_expr::utils::from_plan; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, @@ -97,47 +98,12 @@ fn optimize(plan: &LogicalPlan) -> Result { let new_exprs = plan .expressions() .into_iter() - .map(|expr| { - let original_name = name_for_alias(&expr)?; - let expr = expr.rewrite(&mut expr_rewriter)?; - add_alias_if_changed(&original_name, expr) - }) + .map(|expr| rewrite_preserving_name(expr, &mut expr_rewriter)) .collect::>>()?; from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice()) } -fn name_for_alias(expr: &Expr) -> Result { - match expr { - Expr::Sort { expr, .. } => name_for_alias(expr), - expr => expr.name(), - } -} - -fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result { - let new_name = name_for_alias(&expr)?; - - if new_name == original_name { - return Ok(expr); - } - - Ok(match expr { - Expr::Sort { - expr, - asc, - nulls_first, - } => { - let expr = add_alias_if_changed(original_name, *expr)?; - Expr::Sort { - expr: Box::new(expr), - asc, - nulls_first, - } - } - expr => expr.alias(original_name), - }) -} - struct UnwrapCastExprRewriter { schema: DFSchemaRef, } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index d962dd7b45b9..a1174276d546 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -20,6 +20,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_common::{plan_err, Column, DFSchemaRef}; +use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; use datafusion_expr::{ and, col, combine_filters, @@ -315,13 +316,63 @@ pub fn alias_cols(cols: &[Column]) -> Vec { .collect() } +/// Rewrites `expr` using `rewriter`, ensuring that the output has the +/// same name as `expr` prior to rewrite, adding an alias if necessary. +/// +/// This is important when optimzing plans to ensure the the output +/// schema of plan nodes don't change after optimization +pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result +where + R: ExprRewriter, +{ + let original_name = name_for_alias(&expr)?; + let expr = expr.rewrite(rewriter)?; + add_alias_if_changed(original_name, expr) +} + +/// Return the name to use for the specific Expr, recursing into +/// `Expr::Sort` as appropriate +fn name_for_alias(expr: &Expr) -> Result { + match expr { + Expr::Sort { expr, .. } => name_for_alias(expr), + expr => expr.name(), + } +} + +/// Ensure `expr` has the name name as `original_name` by adding an +/// alias if necessary. +fn add_alias_if_changed(original_name: String, expr: Expr) -> Result { + let new_name = name_for_alias(&expr)?; + + if new_name == original_name { + return Ok(expr); + } + + Ok(match expr { + Expr::Sort { + expr, + asc, + nulls_first, + } => { + let expr = add_alias_if_changed(original_name, *expr)?; + Expr::Sort { + expr: Box::new(expr), + asc, + nulls_first, + } + } + expr => expr.alias(original_name), + }) +} + #[cfg(test)] mod tests { use super::*; use arrow::datatypes::DataType; use datafusion_common::Column; - use datafusion_expr::{col, utils::expr_to_columns}; + use datafusion_expr::{col, lit, utils::expr_to_columns}; use std::collections::HashSet; + use std::ops::Add; #[test] fn test_collect_expr() -> Result<()> { @@ -344,4 +395,73 @@ mod tests { assert!(accum.contains(&Column::from_name("a"))); Ok(()) } + + #[test] + fn test_rewrite_preserving_name() { + test_rewrite(col("a"), col("a")); + + test_rewrite(col("a"), col("b")); + + // cast data types + test_rewrite( + col("a"), + Expr::Cast { + expr: Box::new(col("a")), + data_type: DataType::Int32, + }, + ); + + // change literal type from i32 to i64 + test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64))); + + // SortExpr a+1 ==> b + 2 + test_rewrite( + Expr::Sort { + expr: Box::new(col("a").add(lit(1i32))), + asc: true, + nulls_first: false, + }, + Expr::Sort { + expr: Box::new(col("b").add(lit(2i64))), + asc: true, + nulls_first: false, + }, + ); + } + + /// rewrites `expr_from` to `rewrite_to` using + /// `rewrite_preserving_name` verifying the result is `expected_expr` + fn test_rewrite(expr_from: Expr, rewrite_to: Expr) { + struct TestRewriter { + rewrite_to: Expr, + } + + impl ExprRewriter for TestRewriter { + fn mutate(&mut self, _: Expr) -> Result { + Ok(self.rewrite_to.clone()) + } + } + + let mut rewriter = TestRewriter { + rewrite_to: rewrite_to.clone(), + }; + let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap(); + + let original_name = match &expr_from { + Expr::Sort { expr, .. } => expr.name(), + expr => expr.name(), + } + .unwrap(); + + let new_name = match &expr { + Expr::Sort { expr, .. } => expr.name(), + expr => expr.name(), + } + .unwrap(); + + assert_eq!( + original_name, new_name, + "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" + ) + } }