Skip to content

Commit

Permalink
Consolidate and better tests for expression re-rewriting / aliasing
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Oct 5, 2022
1 parent f6e8ffa commit cec55d1
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 63 deletions.
4 changes: 2 additions & 2 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,8 @@ impl Expr {
}

/// Return `self AS name` alias expression
pub fn alias(self, name: &str) -> Expr {
Expr::Alias(Box::new(self), name.to_owned())
pub fn alias(self, name: impl Into<String>) -> Expr {
Expr::Alias(Box::new(self), name.into())
}

/// Return `self IN <list>` if `negated` is false, otherwise
Expand Down
33 changes: 10 additions & 23 deletions datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

//! Optimizer rule for type validation and coercion
use crate::utils::rewrite_preserving_name;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
use datafusion_expr::binary_rule::{coerce_types, comparison_coercion};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion};
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::data_types;
use datafusion_expr::utils::from_plan;
Expand Down Expand Up @@ -87,30 +88,13 @@ fn optimize_internal(
schema: Arc::new(schema),
};

let original_expr_names: Vec<Option<String>> = plan
.expressions()
.iter()
.map(|expr| expr.name().ok())
.collect();

let new_expr = plan
.expressions()
.into_iter()
.zip(original_expr_names)
.map(|(expr, original_name)| {
let expr = expr.rewrite(&mut expr_rewrite)?;

.map(|expr| {
// ensure aggregate names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
if matches!(expr, Expr::AggregateFunction { .. }) {
if let Some((alias, name)) = original_name.zip(expr.name().ok()) {
if alias != name {
return Ok(expr.alias(&alias));
}
}
}

Ok(expr)
rewrite_preserving_name(expr, &mut expr_rewrite)
})
.collect::<Result<Vec<_>>>()?;

Expand Down Expand Up @@ -637,7 +621,8 @@ mod test {
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n EmptyRelation",
"Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\
\n EmptyRelation",
&format!("{:?}", plan)
);
// a in (1,4,8), a is decimal
Expand All @@ -655,7 +640,8 @@ mod test {
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n EmptyRelation",
"Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\
\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
Expand Down Expand Up @@ -753,7 +739,8 @@ mod test {
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config).unwrap();
assert_eq!(
"Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation",
"Projection: a LIKE CAST(NULL AS Utf8) AS a LIKE NULL \
\n EmptyRelation",
&format!("{:?}", plan)
);

Expand Down
40 changes: 3 additions & 37 deletions datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
//! Unwrap-cast binary comparison rule can be used to the binary/inlist comparison expr now, and other type
//! of expr can be added if needed.
//! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr.
use crate::utils::rewrite_preserving_name;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::{
DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
Expand Down Expand Up @@ -97,47 +98,12 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
let new_exprs = plan
.expressions()
.into_iter()
.map(|expr| {
let original_name = name_for_alias(&expr)?;
let expr = expr.rewrite(&mut expr_rewriter)?;
add_alias_if_changed(&original_name, expr)
})
.map(|expr| rewrite_preserving_name(expr, &mut expr_rewriter))
.collect::<Result<Vec<_>>>()?;

from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
}

fn name_for_alias(expr: &Expr) -> Result<String> {
match expr {
Expr::Sort { expr, .. } => name_for_alias(expr),
expr => expr.name(),
}
}

fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result<Expr> {
let new_name = name_for_alias(&expr)?;

if new_name == original_name {
return Ok(expr);
}

Ok(match expr {
Expr::Sort {
expr,
asc,
nulls_first,
} => {
let expr = add_alias_if_changed(original_name, *expr)?;
Expr::Sort {
expr: Box::new(expr),
asc,
nulls_first,
}
}
expr => expr.alias(original_name),
})
}

struct UnwrapCastExprRewriter {
schema: DFSchemaRef,
}
Expand Down
122 changes: 121 additions & 1 deletion datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_common::{plan_err, Column, DFSchemaRef};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
use datafusion_expr::{
and, col, combine_filters,
Expand Down Expand Up @@ -315,13 +316,63 @@ pub fn alias_cols(cols: &[Column]) -> Vec<Expr> {
.collect()
}

/// Rewrites `expr` using `rewriter`, ensuring that the output has the
/// same name as `expr` prior to rewrite, adding an alias if necessary.
///
/// This is important when optimzing plans to ensure the the output
/// schema of plan nodes don't change after optimization
pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr>
where
R: ExprRewriter<Expr>,
{
let original_name = name_for_alias(&expr)?;
let expr = expr.rewrite(rewriter)?;
add_alias_if_changed(original_name, expr)
}

/// Return the name to use for the specific Expr, recursing into
/// `Expr::Sort` as appropriate
fn name_for_alias(expr: &Expr) -> Result<String> {
match expr {
Expr::Sort { expr, .. } => name_for_alias(expr),
expr => expr.name(),
}
}

/// Ensure `expr` has the name name as `original_name` by adding an
/// alias if necessary.
fn add_alias_if_changed(original_name: String, expr: Expr) -> Result<Expr> {
let new_name = name_for_alias(&expr)?;

if new_name == original_name {
return Ok(expr);
}

Ok(match expr {
Expr::Sort {
expr,
asc,
nulls_first,
} => {
let expr = add_alias_if_changed(original_name, *expr)?;
Expr::Sort {
expr: Box::new(expr),
asc,
nulls_first,
}
}
expr => expr.alias(original_name),
})
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::DataType;
use datafusion_common::Column;
use datafusion_expr::{col, utils::expr_to_columns};
use datafusion_expr::{col, lit, utils::expr_to_columns};
use std::collections::HashSet;
use std::ops::Add;

#[test]
fn test_collect_expr() -> Result<()> {
Expand All @@ -344,4 +395,73 @@ mod tests {
assert!(accum.contains(&Column::from_name("a")));
Ok(())
}

#[test]
fn test_rewrite_preserving_name() {
test_rewrite(col("a"), col("a"));

test_rewrite(col("a"), col("b"));

// cast data types
test_rewrite(
col("a"),
Expr::Cast {
expr: Box::new(col("a")),
data_type: DataType::Int32,
},
);

// change literal type from i32 to i64
test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));

// SortExpr a+1 ==> b + 2
test_rewrite(
Expr::Sort {
expr: Box::new(col("a").add(lit(1i32))),
asc: true,
nulls_first: false,
},
Expr::Sort {
expr: Box::new(col("b").add(lit(2i64))),
asc: true,
nulls_first: false,
},
);
}

/// rewrites `expr_from` to `rewrite_to` using
/// `rewrite_preserving_name` verifying the result is `expected_expr`
fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
struct TestRewriter {
rewrite_to: Expr,
}

impl ExprRewriter for TestRewriter {
fn mutate(&mut self, _: Expr) -> Result<Expr> {
Ok(self.rewrite_to.clone())
}
}

let mut rewriter = TestRewriter {
rewrite_to: rewrite_to.clone(),
};
let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap();

let original_name = match &expr_from {
Expr::Sort { expr, .. } => expr.name(),
expr => expr.name(),
}
.unwrap();

let new_name = match &expr {
Expr::Sort { expr, .. } => expr.name(),
expr => expr.name(),
}
.unwrap();

assert_eq!(
original_name, new_name,
"mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
)
}
}

0 comments on commit cec55d1

Please sign in to comment.