diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index fda390f37961..5ede43a05134 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -23,14 +23,14 @@ use std::sync::Arc; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; +use crate::utils::NamePreserver; use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; -use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, @@ -85,12 +85,32 @@ impl UnwrapCastInComparison { impl OptimizerRule for UnwrapCastInComparison { fn try_optimize( &self, - plan: &LogicalPlan, + _plan: &LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { + internal_err!("Should have called UnwrapCastInComparison::rewrite") + } + + fn name(&self) -> &str { + "unwrap_cast_in_comparison" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { let mut schema = merge_schema(plan.inputs()); - if let LogicalPlan::TableScan(ts) = plan { + if let LogicalPlan::TableScan(ts) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), &ts.source.schema(), @@ -104,22 +124,12 @@ impl OptimizerRule for UnwrapCastInComparison { schema: Arc::new(schema), }; - let new_exprs = plan - .expressions() - .into_iter() - .map(|expr| rewrite_preserving_name(expr, &mut expr_rewriter)) - .collect::>>()?; - - let inputs = plan.inputs().into_iter().cloned().collect(); - plan.with_new_exprs(new_exprs, inputs).map(Some) - } - - fn name(&self) -> &str { - "unwrap_cast_in_comparison" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) + let name_preserver = NamePreserver::new(&plan); + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr)?; + expr.rewrite(&mut expr_rewriter)? + .map_data(|expr| original_name.restore(expr)) + }) } }