Skip to content

Commit

Permalink
Refactor UnwrapCastInComparison to use rewrite()
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Apr 15, 2024
1 parent a165b7f commit 684fa0a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 24 deletions.
13 changes: 9 additions & 4 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,16 @@ pub fn unalias(expr: Expr) -> Expr {
///
/// This is important when optimizing plans to ensure the output
/// schema of plan nodes don't change after optimization
pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr>
pub fn rewrite_preserving_name<R>(
expr: Expr,
rewriter: &mut R,
) -> Result<Transformed<Expr>>
where
R: TreeNodeRewriter<Node = Expr>,
{
let original_name = expr.name_for_alias()?;
let expr = expr.rewrite(rewriter)?.data;
expr.alias_if_changed(original_name)
expr.rewrite(rewriter)?
.map_data(|expr| expr.alias_if_changed(original_name))
}

#[cfg(test)]
Expand Down Expand Up @@ -478,7 +481,9 @@ mod test {
let mut rewriter = TestRewriter {
rewrite_to: rewrite_to.clone(),
};
let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap();
let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter)
.data()
.unwrap();

let original_name = match &expr_from {
Expr::Sort(Sort { expr, .. }) => expr.display_name(),
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::sync::Arc;
use arrow::datatypes::{DataType, IntervalUnit};

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNodeRewriter};
use datafusion_common::{
exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef,
DataFusionError, Result, ScalarValue,
Expand Down Expand Up @@ -109,7 +109,7 @@ fn analyze_internal(
.map(|expr| {
// ensure aggregate names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
rewrite_preserving_name(expr, &mut expr_rewrite)
rewrite_preserving_name(expr, &mut expr_rewrite).data()
})
.collect::<Result<Vec<_>>>()?;

Expand Down
41 changes: 23 additions & 18 deletions datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,32 @@ impl UnwrapCastInComparison {
impl OptimizerRule for UnwrapCastInComparison {
fn try_optimize(
&self,
plan: &LogicalPlan,
_plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
internal_err!("Should have called UnwrapCastInComparison::rewrite")
}

fn name(&self) -> &str {
"unwrap_cast_in_comparison"
}

fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::BottomUp)
}

fn supports_rewrite(&self) -> bool {
true
}

fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
let mut schema = merge_schema(plan.inputs());

if let LogicalPlan::TableScan(ts) = plan {
if let LogicalPlan::TableScan(ts) = &plan {
let source_schema = DFSchema::try_from_qualified_schema(
ts.table_name.clone(),
&ts.source.schema(),
Expand All @@ -104,22 +124,7 @@ impl OptimizerRule for UnwrapCastInComparison {
schema: Arc::new(schema),
};

let new_exprs = plan
.expressions()
.into_iter()
.map(|expr| rewrite_preserving_name(expr, &mut expr_rewriter))
.collect::<Result<Vec<_>>>()?;

let inputs = plan.inputs().into_iter().cloned().collect();
plan.with_new_exprs(new_exprs, inputs).map(Some)
}

fn name(&self) -> &str {
"unwrap_cast_in_comparison"
}

fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::BottomUp)
plan.map_expressions(|expr| rewrite_preserving_name(expr, &mut expr_rewriter))
}
}

Expand Down

0 comments on commit 684fa0a

Please sign in to comment.