From b5c23c2c29f0e8ffc0f7021c78274176cf2cacb0 Mon Sep 17 00:00:00 2001 From: jakevin Date: Mon, 24 Oct 2022 21:01:56 +0800 Subject: [PATCH] Refactor Expr::Cast to use a struct. (#3931) * Refactor Expr::Cast to use a struct. * fix * fix fmt * fix review --- benchmarks/src/bin/tpch.rs | 40 ++++++------- .../core/src/physical_optimizer/pruning.rs | 13 ++-- datafusion/core/src/physical_plan/planner.rs | 51 ++++++++-------- .../core/tests/provider_filter_pushdown.rs | 10 ++-- datafusion/expr/src/expr.rs | 32 ++++++---- datafusion/expr/src/expr_fn.rs | 7 +-- datafusion/expr/src/expr_rewriter.rs | 11 ++-- datafusion/expr/src/expr_schema.rs | 11 ++-- datafusion/expr/src/expr_visitor.rs | 3 +- .../optimizer/src/projection_push_down.rs | 16 ++--- datafusion/optimizer/src/reduce_outer_join.rs | 18 +++--- .../optimizer/src/simplify_expressions.rs | 27 +++------ .../src/unwrap_cast_in_comparison.rs | 10 ++-- datafusion/optimizer/src/utils.rs | 18 ++---- datafusion/physical-expr/src/planner.rs | 3 +- datafusion/proto/src/from_proto.rs | 4 +- datafusion/proto/src/lib.rs | 7 +-- datafusion/proto/src/to_proto.rs | 59 ++++++++++--------- datafusion/sql/src/planner.rs | 57 +++++++++--------- datafusion/sql/src/utils.rs | 12 ++-- 20 files changed, 200 insertions(+), 209 deletions(-) diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index b2f2bf181787..7a914b19b9e1 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -614,8 +614,8 @@ mod tests { use datafusion::arrow::array::*; use datafusion::arrow::util::display::array_value_to_string; + use datafusion::logical_expr::expr::Cast; use datafusion::logical_expr::Expr; - use datafusion::logical_expr::Expr::Cast; use datafusion::logical_expr::Expr::ScalarFunction; use datafusion::sql::TableReference; @@ -798,9 +798,9 @@ mod tests { let path = Path::new(&path); if let Ok(expected) = read_text_file(path) { assert_eq!(expected, actual, - // generate output that is easier to copy/paste/update - "\n\nMismatch of expected content in: {:?}\nExpected:\n\n{}\n\nActual:\n\n{}\n\n", - path, expected, actual); + // generate output that is easier to copy/paste/update + "\n\nMismatch of expected content in: {:?}\nExpected:\n\n{}\n\nActual:\n\n{}\n\n", + path, expected, actual); found = true; break; } @@ -1264,10 +1264,10 @@ mod tests { args: vec![col(Field::name(field)).mul(lit(100))], }.div(lit(100))); Expr::Alias( - Box::new(Cast { - expr: round, - data_type: DataType::Decimal128(38, 2), - }), + Box::new(Expr::Cast(Cast::new( + round, + DataType::Decimal128(38, 2), + ))), Field::name(field).to_string(), ) } @@ -1343,23 +1343,23 @@ mod tests { DataType::Decimal128(_, _) => { // there's no support for casting from Utf8 to Decimal, so // we'll cast from Utf8 to Float64 to Decimal for Decimal types - let inner_cast = Box::new(Cast { - expr: Box::new(trim(col(Field::name(field)))), - data_type: DataType::Float64, - }); + let inner_cast = Box::new(Expr::Cast(Cast::new( + Box::new(trim(col(Field::name(field)))), + DataType::Float64, + ))); Expr::Alias( - Box::new(Cast { - expr: inner_cast, - data_type: Field::data_type(field).to_owned(), - }), + Box::new(Expr::Cast(Cast::new( + inner_cast, + Field::data_type(field).to_owned(), + ))), Field::name(field).to_string(), ) } _ => Expr::Alias( - Box::new(Cast { - expr: Box::new(trim(col(Field::name(field)))), - data_type: Field::data_type(field).to_owned(), - }), + Box::new(Expr::Cast(Cast::new( + Box::new(trim(col(Field::name(field)))), + Field::data_type(field).to_owned(), + ))), Field::name(field).to_string(), ), } diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 0b309412c0c6..59ecdaba3d5a 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -46,7 +46,7 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion_common::{downcast_value, ScalarValue}; -use datafusion_expr::expr::BinaryExpr; +use datafusion_expr::expr::{BinaryExpr, Cast}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; use datafusion_expr::utils::expr_to_columns; use datafusion_expr::{binary_expr, cast, try_cast, ExprSchemable}; @@ -190,11 +190,10 @@ impl PruningPredicate { let predicate_array = downcast_value!(array, BooleanArray); Ok(predicate_array - .into_iter() - .map(|x| x.unwrap_or(true)) // None -> true per comments above - .collect::>()) - - }, + .into_iter() + .map(|x| x.unwrap_or(true)) // None -> true per comments above + .collect::>()) + } // result was a column ColumnarValue::Scalar(ScalarValue::Boolean(v)) => { let v = v.unwrap_or(true); // None -> true per comments above @@ -530,7 +529,7 @@ fn rewrite_expr_to_prunable( // `col op lit()` Expr::Column(_) => Ok((column_expr.clone(), op, scalar_expr.clone())), // `cast(col) op lit()` - Expr::Cast { expr, data_type } => { + Expr::Cast(Cast { expr, data_type }) => { let from_type = expr.get_type(&schema)?; verify_support_type_for_prune(&from_type, data_type)?; let (left, op, right) = diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 1995a6196eed..4a8399af920c 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -59,7 +59,9 @@ use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; use async_trait::async_trait; use datafusion_common::{DFSchema, ScalarValue}; -use datafusion_expr::expr::{Between, BinaryExpr, GetIndexedField, GroupingSet, Like}; +use datafusion_expr::expr::{ + Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like, +}; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::utils::{expand_wildcard, expr_to_columns}; use datafusion_expr::WindowFrameUnits; @@ -126,7 +128,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { name += "END"; Ok(name) } - Expr::Cast { expr, .. } => { + Expr::Cast(Cast { expr, .. }) => { // CAST does not change the expression name create_physical_name(expr, false) } @@ -462,7 +464,7 @@ impl DefaultPhysicalPlanner { ) -> BoxFuture<'a, Result>> { async move { let exec_plan: Result> = match logical_plan { - LogicalPlan::TableScan (TableScan { + LogicalPlan::TableScan(TableScan { source, projection, filters, @@ -484,7 +486,7 @@ impl DefaultPhysicalPlanner { let exec_schema = schema.as_ref().to_owned().into(); let exprs = values.iter() .map(|row| { - row.iter().map(|expr|{ + row.iter().map(|expr| { self.create_physical_expr( expr, schema, @@ -497,7 +499,7 @@ impl DefaultPhysicalPlanner { .collect::>>()?; let value_exec = ValuesExec::try_new( SchemaRef::new(exec_schema), - exprs + exprs, )?; Ok(Arc::new(value_exec)) } @@ -612,7 +614,7 @@ impl DefaultPhysicalPlanner { window_expr, input_exec, physical_input_schema, - )?) ) + )?)) } LogicalPlan::Aggregate(Aggregate { input, @@ -692,16 +694,16 @@ impl DefaultPhysicalPlanner { aggregates, initial_aggr, physical_input_schema.clone(), - )?) ) + )?)) } - LogicalPlan::Distinct(Distinct {input}) => { + LogicalPlan::Distinct(Distinct { input }) => { // Convert distinct to groupby with no aggregations let group_expr = expand_wildcard(input.schema(), input)?; - let aggregate = LogicalPlan::Aggregate(Aggregate::try_new_with_schema( - input.clone(), - group_expr, - vec![], - input.schema().clone() // input schema and aggregate schema are the same in this case + let aggregate = LogicalPlan::Aggregate(Aggregate::try_new_with_schema( + input.clone(), + group_expr, + vec![], + input.schema().clone(), // input schema and aggregate schema are the same in this case )?); Ok(self.create_initial_plan(&aggregate, session_state).await?) } @@ -755,7 +757,7 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(ProjectionExec::try_new( physical_exprs, input_exec, - )?) ) + )?)) } LogicalPlan::Filter(filter) => { let physical_input = self.create_initial_plan(filter.input(), session_state).await?; @@ -768,14 +770,14 @@ impl DefaultPhysicalPlanner { &input_schema, session_state, )?; - Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?) ) + Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?)) } LogicalPlan::Union(Union { inputs, .. }) => { let physical_plans = futures::stream::iter(inputs) .then(|lp| self.create_initial_plan(lp, session_state)) .try_collect::>() .await?; - Ok(Arc::new(UnionExec::new(physical_plans)) ) + Ok(Arc::new(UnionExec::new(physical_plans))) } LogicalPlan::Repartition(Repartition { input, @@ -803,13 +805,13 @@ impl DefaultPhysicalPlanner { Partitioning::Hash(runtime_expr, *n) } LogicalPartitioning::DistributeBy(_) => { - return Err(DataFusionError::NotImplemented("Physical plan does not support DistributeBy partitioning".to_string())) + return Err(DataFusionError::NotImplemented("Physical plan does not support DistributeBy partitioning".to_string())); } }; Ok(Arc::new(RepartitionExec::try_new( physical_input, physical_partitioning, - )?) ) + )?)) } LogicalPlan::Sort(Sort { expr, input, fetch, .. }) => { let physical_input = self.create_initial_plan(input, session_state).await?; @@ -852,7 +854,8 @@ impl DefaultPhysicalPlanner { Arc::new(merge) } else { Arc::new(SortExec::try_new(sort_expr, physical_input, *fetch)?) - }) } + }) + } LogicalPlan::Join(Join { left, right, @@ -922,14 +925,14 @@ impl DefaultPhysicalPlanner { expr, &filter_df_schema, &filter_schema, - &session_state.execution_props + &session_state.execution_props, )?; let column_indices = join_utils::JoinFilter::build_column_indices(left_field_indices, right_field_indices); Some(join_utils::JoinFilter::new( filter_expr, column_indices, - filter_schema + filter_schema, )) } _ => None @@ -995,7 +998,7 @@ impl DefaultPhysicalPlanner { *produce_one_row, SchemaRef::new(schema.as_ref().to_owned().into()), ))), - LogicalPlan::SubqueryAlias(SubqueryAlias { input,.. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { match input.as_ref() { LogicalPlan::TableScan(..) => { self.create_initial_plan(input, session_state).await @@ -1003,7 +1006,7 @@ impl DefaultPhysicalPlanner { _ => Err(DataFusionError::Plan("SubqueryAlias should only wrap TableScan".to_string())) } } - LogicalPlan::Limit(Limit { input, skip, fetch,.. }) => { + LogicalPlan::Limit(Limit { input, skip, fetch, .. }) => { let input = self.create_initial_plan(input, session_state).await?; // GlobalLimitExec requires a single partition for input @@ -1055,7 +1058,7 @@ impl DefaultPhysicalPlanner { SchemaRef::new(Schema::empty()), ))) } - LogicalPlan::Explain (_) => Err(DataFusionError::Internal( + LogicalPlan::Explain(_) => Err(DataFusionError::Internal( "Unsupported logical plan: Explain must be root of the plan".to_string(), )), LogicalPlan::Analyze(a) => { diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs index 672f7ee1a3ff..c1aa5ad7095c 100644 --- a/datafusion/core/tests/provider_filter_pushdown.rs +++ b/datafusion/core/tests/provider_filter_pushdown.rs @@ -32,7 +32,7 @@ use datafusion::physical_plan::{ use datafusion::prelude::*; use datafusion::scalar::ScalarValue; use datafusion_common::DataFusionError; -use datafusion_expr::expr::BinaryExpr; +use datafusion_expr::expr::{BinaryExpr, Cast}; use std::ops::Deref; use std::sync::Arc; @@ -153,7 +153,7 @@ impl TableProvider for CustomProvider { Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64, Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64, Expr::Literal(ScalarValue::Int64(Some(i))) => *i as i64, - Expr::Cast { expr, data_type: _ } => match expr.deref() { + Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() { Expr::Literal(lit_value) => match lit_value { ScalarValue::Int8(Some(v)) => *v as i64, ScalarValue::Int16(Some(v)) => *v as i64, @@ -163,21 +163,21 @@ impl TableProvider for CustomProvider { return Err(DataFusionError::NotImplemented(format!( "Do not support value {:?}", other_value - ))) + ))); } }, other_expr => { return Err(DataFusionError::NotImplemented(format!( "Do not support expr {:?}", other_expr - ))) + ))); } }, other_expr => { return Err(DataFusionError::NotImplemented(format!( "Do not support expr {:?}", other_expr - ))) + ))); } }; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 1d11245c3881..0108f92d3ae0 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -143,12 +143,7 @@ pub enum Expr { Case(Case), /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. - Cast { - /// The expression being cast - expr: Box, - /// The `DataType` the expression will yield - data_type: DataType, - }, + Cast(Cast), /// Casts the expression to a given type and will return a null value if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. TryCast { @@ -360,6 +355,22 @@ impl GetIndexedField { } } +/// Cast expression +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct Cast { + /// The expression being cast + pub expr: Box, + /// The `DataType` the expression will yield + pub data_type: DataType, +} + +impl Cast { + /// Create a new Cast expression + pub fn new(expr: Box, data_type: DataType) -> Self { + Self { expr, data_type } + } +} + /// Grouping sets /// See https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS /// for Postgres definition. @@ -682,7 +693,7 @@ impl fmt::Debug for Expr { } write!(f, "END") } - Expr::Cast { expr, data_type } => { + Expr::Cast(Cast { expr, data_type }) => { write!(f, "CAST({:?} AS {:?})", expr, data_type) } Expr::TryCast { expr, data_type } => { @@ -1038,7 +1049,7 @@ fn create_name(e: &Expr) -> Result { name += "END"; Ok(name) } - Expr::Cast { expr, .. } => { + Expr::Cast(Cast { expr, .. }) => { // CAST does not change the expression name create_name(expr) } @@ -1212,6 +1223,7 @@ fn create_names(exprs: &[Expr]) -> Result { #[cfg(test)] mod test { + use crate::expr::Cast; use crate::expr_fn::col; use crate::{case, lit, Expr}; use arrow::datatypes::DataType; @@ -1233,10 +1245,10 @@ mod test { #[test] fn format_cast() -> Result<()> { - let expr = Expr::Cast { + let expr = Expr::Cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))), data_type: DataType::Utf8, - }; + }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, expr.canonical_name()); assert_eq!(expected_canonical, format!("{}", expr)); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 9d8b05f2b155..0685ec835140 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -17,7 +17,7 @@ //! Functions for creating logical expressions -use crate::expr::{BinaryExpr, GroupingSet}; +use crate::expr::{BinaryExpr, Cast, GroupingSet}; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF, @@ -259,10 +259,7 @@ pub fn rollup(exprs: Vec) -> Expr { /// Create a cast expression pub fn cast(expr: Expr, data_type: DataType) -> Expr { - Expr::Cast { - expr: Box::new(expr), - data_type, - } + Expr::Cast(Cast::new(Box::new(expr), data_type)) } /// Create a try cast expression diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index 89cadc5ee8c4..8ac645617a45 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -17,7 +17,7 @@ //! Expression rewriter -use crate::expr::{Between, BinaryExpr, Case, GetIndexedField, GroupingSet, Like}; +use crate::expr::{Between, BinaryExpr, Case, Cast, GetIndexedField, GroupingSet, Like}; use crate::logical_plan::{Aggregate, Projection}; use crate::utils::{from_plan, grouping_set_to_exprlist}; use crate::{Expr, ExprSchemable, LogicalPlan}; @@ -203,10 +203,9 @@ impl ExprRewritable for Expr { Expr::Case(Case::new(expr, when_then_expr, else_expr)) } - Expr::Cast { expr, data_type } => Expr::Cast { - expr: rewrite_boxed(expr, rewriter)?, - data_type, - }, + Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast::new(rewrite_boxed(expr, rewriter)?, data_type)) + } Expr::TryCast { expr, data_type } => Expr::TryCast { expr: rewrite_boxed(expr, rewriter)?, data_type, @@ -566,6 +565,7 @@ mod test { struct RecordingRewriter { v: Vec, } + impl ExprRewriter for RecordingRewriter { fn mutate(&mut self, expr: Expr) -> Result { self.v.push(format!("Mutated {:?}", expr)); @@ -593,6 +593,7 @@ mod test { /// rewrites all "foo" string literals to "bar" struct FooBarRewriter {} + impl ExprRewriter for FooBarRewriter { fn mutate(&mut self, expr: Expr) -> Result { match expr { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index e19f6a9fcb21..8424fa2aa2d1 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -16,7 +16,7 @@ // under the License. use super::{Between, Expr, Like}; -use crate::expr::{BinaryExpr, GetIndexedField}; +use crate::expr::{BinaryExpr, Cast, GetIndexedField}; use crate::field_util::get_indexed_field; use crate::type_coercion::binary::binary_operator_data_type; use crate::{aggregate_function, function, window_function}; @@ -61,7 +61,7 @@ impl ExprSchemable for Expr { Expr::ScalarVariable(ty, _) => Ok(ty.clone()), Expr::Literal(l) => Ok(l.get_datatype()), Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), - Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => { + Expr::Cast(Cast { data_type, .. }) | Expr::TryCast { data_type, .. } => { Ok(data_type.clone()) } Expr::ScalarUDF { fun, args } => { @@ -182,7 +182,7 @@ impl ExprSchemable for Expr { Ok(true) } } - Expr::Cast { expr, .. } => expr.nullable(input_schema), + Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::ScalarFunction { .. } @@ -262,10 +262,7 @@ impl ExprSchemable for Expr { if this_type == *cast_to_type { Ok(self) } else if can_cast_types(&this_type, cast_to_type) { - Ok(Expr::Cast { - expr: Box::new(self), - data_type: cast_to_type.clone(), - }) + Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))) } else { Err(DataFusionError::Plan(format!( "Cannot automatically convert {:?} to {:?}", diff --git a/datafusion/expr/src/expr_visitor.rs b/datafusion/expr/src/expr_visitor.rs index 4014847bc2aa..bd839f098fc3 100644 --- a/datafusion/expr/src/expr_visitor.rs +++ b/datafusion/expr/src/expr_visitor.rs @@ -17,6 +17,7 @@ //! Expression visitor +use crate::expr::Cast; use crate::{ expr::{BinaryExpr, GroupingSet}, Between, Expr, GetIndexedField, Like, @@ -108,7 +109,7 @@ impl ExprVisitable for Expr { | Expr::IsNotUnknown(expr) | Expr::IsNull(expr) | Expr::Negative(expr) - | Expr::Cast { expr, .. } + | Expr::Cast(Cast { expr, .. }) | Expr::TryCast { expr, .. } | Expr::Sort { expr, .. } | Expr::InSubquery { expr, .. } => expr.accept(visitor), diff --git a/datafusion/optimizer/src/projection_push_down.rs b/datafusion/optimizer/src/projection_push_down.rs index d6ed6e4884bc..b2e776821fb1 100644 --- a/datafusion/optimizer/src/projection_push_down.rs +++ b/datafusion/optimizer/src/projection_push_down.rs @@ -534,10 +534,10 @@ fn projection_equal(p: &Projection, p2: &Projection) -> bool { #[cfg(test)] mod tests { - use super::*; use crate::test::*; use arrow::datatypes::DataType; + use datafusion_expr::expr::Cast; use datafusion_expr::{ col, count, lit, logical_plan::{builder::LogicalPlanBuilder, JoinType}, @@ -699,7 +699,7 @@ mod tests { DFField::new(Some("test"), "b", DataType::UInt32, false), DFField::new(Some("test2"), "c1", DataType::UInt32, false), ], - HashMap::new() + HashMap::new(), )?, ); @@ -742,7 +742,7 @@ mod tests { DFField::new(Some("test"), "b", DataType::UInt32, false), DFField::new(Some("test2"), "c1", DataType::UInt32, false), ], - HashMap::new() + HashMap::new(), )?, ); @@ -783,7 +783,7 @@ mod tests { DFField::new(Some("test"), "b", DataType::UInt32, false), DFField::new(Some("test2"), "a", DataType::UInt32, false), ], - HashMap::new() + HashMap::new(), )?, ); @@ -795,10 +795,10 @@ mod tests { let table_scan = test_table_scan()?; let projection = LogicalPlanBuilder::from(table_scan) - .project(vec![Expr::Cast { - expr: Box::new(col("c")), - data_type: DataType::Float64, - }])? + .project(vec![Expr::Cast(Cast::new( + Box::new(col("c")), + DataType::Float64, + ))])? .build()?; let expected = "Projection: CAST(test.c AS Float64)\ diff --git a/datafusion/optimizer/src/reduce_outer_join.rs b/datafusion/optimizer/src/reduce_outer_join.rs index d0016cc1914b..8b550e11386b 100644 --- a/datafusion/optimizer/src/reduce_outer_join.rs +++ b/datafusion/optimizer/src/reduce_outer_join.rs @@ -25,6 +25,7 @@ use datafusion_expr::{ }; use datafusion_expr::{Expr, Operator}; +use datafusion_expr::expr::Cast; use std::collections::HashMap; use std::sync::Arc; @@ -351,15 +352,14 @@ fn extract_nonnullable_columns( false, ) } - Expr::Cast { expr, data_type: _ } | Expr::TryCast { expr, data_type: _ } => { - extract_nonnullable_columns( - expr, - nonnullable_cols, - left_schema, - right_schema, - false, - ) - } + Expr::Cast(Cast { expr, data_type: _ }) + | Expr::TryCast { expr, data_type: _ } => extract_nonnullable_columns( + expr, + nonnullable_cols, + left_schema, + right_schema, + false, + ), _ => Ok(()), } } diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs index bc3ac6e84b23..32c8c9bce856 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions.rs @@ -278,9 +278,9 @@ fn simpl_concat(args: Vec) -> Result { ) => contiguous_scalar += &v, Expr::Literal(x) => { return Err(DataFusionError::Internal(format!( - "The scalar {} should be casted to string type during the type coercion.", - x - ))) + "The scalar {} should be casted to string type during the type coercion.", + x + ))); } // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` (if it is not empty) and reset it to empty string. @@ -921,7 +921,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { op: Divide, right, }) if !info.nullable(&left)? && is_zero(&right) => { - return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)) + return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); } // @@ -952,7 +952,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { op: Modulo, right, }) if !info.nullable(&left)? && is_zero(&right) => { - return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)) + return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); } // @@ -1079,7 +1079,7 @@ mod tests { use arrow::array::{ArrayRef, Int32Array}; use chrono::{DateTime, TimeZone, Utc}; use datafusion_common::{DFField, DFSchemaRef}; - use datafusion_expr::expr::Case; + use datafusion_expr::expr::{Case, Cast}; use datafusion_expr::expr_fn::{concat, concat_ws}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ @@ -1663,10 +1663,7 @@ mod tests { } fn cast_to_int64_expr(expr: Expr) -> Expr { - Expr::Cast { - expr: expr.into(), - data_type: DataType::Int64, - } + Expr::Cast(Cast::new(expr.into(), DataType::Int64)) } fn to_timestamp_expr(arg: impl Into) -> Expr { @@ -2425,10 +2422,7 @@ mod tests { #[test] fn cast_expr() { let table_scan = test_table_scan(); - let proj = vec![Expr::Cast { - expr: Box::new(lit("0")), - data_type: DataType::Int32, - }]; + let proj = vec![Expr::Cast(Cast::new(Box::new(lit("0")), DataType::Int32))]; let plan = LogicalPlanBuilder::from(table_scan) .project(proj) .unwrap() @@ -2444,10 +2438,7 @@ mod tests { #[test] fn cast_expr_wrong_arg() { let table_scan = test_table_scan(); - let proj = vec![Expr::Cast { - expr: Box::new(lit("")), - data_type: DataType::Int32, - }]; + let proj = vec![Expr::Cast(Cast::new(Box::new(lit("")), DataType::Int32))]; let plan = LogicalPlanBuilder::from(table_scan) .project(proj) .unwrap() diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 93b8b71d6ebb..3dfbaa028187 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -24,7 +24,7 @@ use arrow::datatypes::{ DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::BinaryExpr; +use datafusion_expr::expr::{BinaryExpr, Cast}; use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; use datafusion_expr::utils::from_plan; use datafusion_expr::{ @@ -132,7 +132,7 @@ impl ExprRewriter for UnwrapCastExprRewriter { match (&left, &right) { ( Expr::Literal(left_lit_value), - Expr::TryCast { expr, .. } | Expr::Cast { expr, .. }, + Expr::TryCast { expr, .. } | Expr::Cast(Cast { expr, .. }), ) => { // if the left_lit_value can be casted to the type of expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal @@ -149,7 +149,7 @@ impl ExprRewriter for UnwrapCastExprRewriter { } } ( - Expr::TryCast { expr, .. } | Expr::Cast { expr, .. }, + Expr::TryCast { expr, .. } | Expr::Cast(Cast { expr, .. }), Expr::Literal(right_lit_value), ) => { // if the right_lit_value can be casted to the type of expr @@ -186,10 +186,10 @@ impl ExprRewriter for UnwrapCastExprRewriter { expr: internal_left_expr, .. } - | Expr::Cast { + | Expr::Cast(Cast { expr: internal_left_expr, .. - }, + }), ) = Some(left_expr.as_ref()) { let internal_left = internal_left_expr.as_ref().clone(); diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 130df3e0e6ef..4eda6e3e3cc2 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -469,6 +469,7 @@ mod tests { use super::*; use arrow::datatypes::DataType; use datafusion_common::Column; + use datafusion_expr::expr::Cast; use datafusion_expr::{col, lit, utils::expr_to_columns}; use std::collections::HashSet; use std::ops::Add; @@ -528,7 +529,7 @@ mod tests { vec![ col("a").eq(lit(5)), // no alias on b - col("b") + col("b"), ] ); } @@ -577,17 +578,11 @@ mod tests { fn test_collect_expr() -> Result<()> { let mut accum: HashSet = HashSet::new(); expr_to_columns( - &Expr::Cast { - expr: Box::new(col("a")), - data_type: DataType::Float64, - }, + &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), &mut accum, )?; expr_to_columns( - &Expr::Cast { - expr: Box::new(col("a")), - data_type: DataType::Float64, - }, + &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), &mut accum, )?; assert_eq!(1, accum.len()); @@ -604,10 +599,7 @@ mod tests { // cast data types test_rewrite( col("a"), - Expr::Cast { - expr: Box::new(col("a")), - data_type: DataType::Int32, - }, + Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)), ); // change literal type from i32 to i64 diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 8080c8f3066a..7332b0910cdf 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -27,6 +27,7 @@ use crate::{ }; use arrow::datatypes::{DataType, Schema}; use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; +use datafusion_expr::expr::Cast; use datafusion_expr::{ binary_expr, Between, BinaryExpr, Expr, GetIndexedField, Like, Operator, }; @@ -277,7 +278,7 @@ pub fn create_physical_expr( }; Ok(expressions::case(expr, when_then_expr, else_expr)?) } - Expr::Cast { expr, data_type } => expressions::cast( + Expr::Cast(Cast { expr, data_type }) => expressions::cast( create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, input_schema, data_type.clone(), diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 79b477b3e1ef..7b2210eca6a1 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -31,7 +31,7 @@ use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue, }; -use datafusion_expr::expr::BinaryExpr; +use datafusion_expr::expr::{BinaryExpr, Cast}; use datafusion_expr::{ abs, acos, array, ascii, asin, atan, atan2, bit_length, btrim, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, date_bin, @@ -967,7 +967,7 @@ pub fn parse_expr( ExprType::Cast(cast) => { let expr = Box::new(parse_required_expr(&cast.expr, registry, "expr")?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::Cast { expr, data_type }) + Ok(Expr::Cast(Cast::new(expr, data_type))) } ExprType::TryCast(cast) => { let expr = Box::new(parse_required_expr(&cast.expr, registry, "expr")?); diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 7feae7965215..8dd1b55f5bae 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -63,7 +63,7 @@ mod roundtrip_tests { use datafusion::prelude::{create_udf, CsvReadOptions, SessionContext}; use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; use datafusion_expr::create_udaf; - use datafusion_expr::expr::{Between, BinaryExpr, Case, GroupingSet, Like}; + use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, Like}; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode}; use datafusion_expr::{ col, lit, Accumulator, AggregateFunction, AggregateState, @@ -893,10 +893,7 @@ mod roundtrip_tests { #[test] fn roundtrip_cast() { - let test_expr = Expr::Cast { - expr: Box::new(lit(1.0_f32)), - data_type: DataType::Boolean, - }; + let test_expr = Expr::Cast(Cast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index f8dab779b405..9de52f29141c 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -34,7 +34,9 @@ use arrow::datatypes::{ UnionMode, }; use datafusion_common::{Column, DFField, DFSchemaRef, ScalarValue}; -use datafusion_expr::expr::{Between, BinaryExpr, GetIndexedField, GroupingSet, Like}; +use datafusion_expr::expr::{ + Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like, +}; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, Expr, WindowFrame, WindowFrameBound, @@ -130,8 +132,7 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { type Error = Error; fn try_from(val: &DataType) -> Result { - let res = - match val { + let res = match val { DataType::Null => Self::None(EmptyMessage {}), DataType::Boolean => Self::Bool(EmptyMessage {}), DataType::Int8 => Self::Int8(EmptyMessage {}), @@ -194,7 +195,10 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { UnionMode::Dense => protobuf::UnionMode::Dense, }; Self::Union(protobuf::Union { - union_types: union_types.iter().map(|field| field.try_into()).collect::, Error>>()?, + union_types: union_types + .iter() + .map(|field| field.try_into()) + .collect::, Error>>()?, union_mode: union_mode.into(), type_ids: type_ids.iter().map(|x| *x as i32).collect(), }) @@ -456,7 +460,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { expr_type: Some(ExprType::BinaryExpr(binary_expr)), } } - Expr::Like(Like { negated, expr, pattern, escape_char} ) => { + Expr::Like(Like { negated, expr, pattern, escape_char }) => { let pb = Box::new(protobuf::LikeNode { negated: *negated, expr: Some(Box::new(expr.as_ref().try_into()?)), @@ -469,7 +473,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { expr_type: Some(ExprType::Like(pb)), } } - Expr::ILike(Like { negated, expr, pattern, escape_char} ) => { + Expr::ILike(Like { negated, expr, pattern, escape_char }) => { let pb = Box::new(protobuf::ILikeNode { negated: *negated, expr: Some(Box::new(expr.as_ref().try_into()?)), @@ -482,7 +486,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { expr_type: Some(ExprType::Ilike(pb)), } } - Expr::SimilarTo(Like { negated, expr, pattern, escape_char} ) => { + Expr::SimilarTo(Like { negated, expr, pattern, escape_char }) => { let pb = Box::new(protobuf::SimilarToNode { negated: *negated, expr: Some(Box::new(expr.as_ref().try_into()?)), @@ -598,7 +602,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { filter: match filter { Some(e) => Some(Box::new(e.as_ref().try_into()?)), None => None, - } + }, }; Self { expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), @@ -637,16 +641,15 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { args: args.iter().map(|expr| expr.try_into()).collect::, Error, - >>( - )?, + >>()?, filter: match filter { Some(e) => Some(Box::new(e.as_ref().try_into()?)), None => None, - } + }, }, - ))), + ))), } - }, + } Expr::Not(expr) => { let expr = Box::new(protobuf::Not { expr: Some(Box::new(expr.as_ref().try_into()?)), @@ -760,7 +763,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { expr_type: Some(ExprType::Case(expr)), } } - Expr::Cast { expr, data_type } => { + Expr::Cast(Cast { expr, data_type }) => { let expr = Box::new(protobuf::CastNode { expr: Some(Box::new(expr.as_ref().try_into()?)), arrow_type: Some(data_type.try_into()?), @@ -814,25 +817,24 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } => { // we would need to add logical plan operators to datafusion.proto to support this // see discussion in https://github.com/apache/arrow-datafusion/issues/2565 - return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } not supported".to_string())) + return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } not supported".to_string())); } - Expr::GetIndexedField(GetIndexedField{key, expr}) => + Expr::GetIndexedField(GetIndexedField { key, expr }) => Self { - expr_type: Some(ExprType::GetIndexedField(Box::new( - protobuf::GetIndexedField { - key: Some(key.try_into()?), - expr: Some(Box::new(expr.as_ref().try_into()?)), - }, - ))), - }, + expr_type: Some(ExprType::GetIndexedField(Box::new( + protobuf::GetIndexedField { + key: Some(key.try_into()?), + expr: Some(Box::new(expr.as_ref().try_into()?)), + }, + ))), + }, Expr::GroupingSet(GroupingSet::Cube(exprs)) => Self { expr_type: Some(ExprType::Cube(CubeNode { expr: exprs.iter().map(|expr| expr.try_into()).collect::, Self::Error, - >>( - )?, + >>()?, })), }, Expr::GroupingSet(GroupingSet::Rollup(exprs)) => Self { @@ -840,8 +842,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { expr: exprs.iter().map(|expr| expr.try_into()).collect::, Self::Error, - >>( - )?, + >>()?, })), }, Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => Self { @@ -861,7 +862,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { }, Expr::QualifiedWildcard { .. } | Expr::TryCast { .. } => - return Err(Error::General("Proto serialization error: Expr::QualifiedWildcard { .. } | Expr::TryCast { .. } not supported".to_string())), + return Err(Error::General("Proto serialization error: Expr::QualifiedWildcard { .. } | Expr::TryCast { .. } not supported".to_string())), }; Ok(expr_node) @@ -1066,7 +1067,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { scalar::ScalarValue::FixedSizeBinary(_, _) => { return Err(Error::General( "FixedSizeBinary is not yet implemented".to_owned(), - )) + )); } datafusion::scalar::ScalarValue::Time64(v) => { diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 92264f06023f..7dccbc522f3e 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -51,7 +51,7 @@ use crate::utils::{make_decimal_type, normalize_ident, resolve_columns}; use datafusion_common::{ field_not_found, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::{Between, BinaryExpr, Case, GroupingSet, Like}; +use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, Like}; use datafusion_expr::logical_plan::builder::project_with_alias; use datafusion_expr::logical_plan::{Filter, Subquery}; use datafusion_expr::Expr::Alias; @@ -105,7 +105,7 @@ fn plan_key(key: SQLExpr) -> Result { return Err(DataFusionError::SQL(ParserError(format!( "Unsuported index key expression: {:?}", key - )))) + )))); } }; @@ -1706,18 +1706,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &schema, &mut HashMap::new(), ), - SQLExpr::TypedString { data_type, value } => Ok(Expr::Cast { - expr: Box::new(lit(value)), - data_type: convert_data_type(&data_type)?, - }), - SQLExpr::Cast { expr, data_type } => Ok(Expr::Cast { - expr: Box::new(self.sql_expr_to_logical_expr( + SQLExpr::TypedString { data_type, value } => { + Ok(Expr::Cast(Cast::new( + Box::new(lit(value)), + convert_data_type(&data_type)?, + ))) + } + SQLExpr::Cast { expr, data_type } => Ok(Expr::Cast(Cast::new( + Box::new(self.sql_expr_to_logical_expr( *expr, &schema, &mut HashMap::new(), )?), - data_type: convert_data_type(&data_type)?, - }), + convert_data_type(&data_type)?, + ))), other => Err(DataFusionError::NotImplemented(format!( "Unsupported value {:?} in a values list expression", other @@ -1823,14 +1825,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { match (var_names.pop(), var_names.pop()) { (Some(name), Some(relation)) if var_names.is_empty() => { - match schema.field_with_qualified_name(&relation, &name) { + match schema.field_with_qualified_name(&relation, &name) { Ok(_) => { // found an exact match on a qualified name so this is a table.column identifier Ok(Expr::Column(Column { relation: Some(relation), name, })) - }, + } Err(_) => { if let Some(field) = schema.fields().iter().find(|f| f.name().eq(&relation)) { // Access to a field of a column which is a structure, example: SELECT my_struct.key @@ -1895,10 +1897,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Cast { expr, data_type, - } => Ok(Expr::Cast { - expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?), - data_type: convert_data_type(&data_type)?, - }), + } => Ok(Expr::Cast(Cast::new( + Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?), + convert_data_type(&data_type)?, + ))), SQLExpr::TryCast { expr, @@ -1911,10 +1913,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::TypedString { data_type, value, - } => Ok(Expr::Cast { - expr: Box::new(lit(value)), - data_type: convert_data_type(&data_type)?, - }), + } => Ok(Expr::Cast(Cast::new( + Box::new(lit(value)), + convert_data_type(&data_type)?, + ))), SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new( self.sql_expr_to_logical_expr(*expr, schema, ctes)?, @@ -1991,7 +1993,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { negated, Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?), Box::new(pattern), - escape_char + escape_char, ))) } @@ -2007,7 +2009,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { negated, Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?), Box::new(pattern), - escape_char + escape_char, ))) } @@ -2119,10 +2121,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match self.sql_expr_to_logical_expr(*expr, schema, ctes)? { Expr::AggregateFunction { fun, args, distinct, .. - } => Ok(Expr::AggregateFunction { fun, args, distinct, filter: Some(Box::new(self.sql_expr_to_logical_expr(*filter, schema, ctes)?)) }), + } => Ok(Expr::AggregateFunction { fun, args, distinct, filter: Some(Box::new(self.sql_expr_to_logical_expr(*filter, schema, ctes)?)) }), _ => Err(DataFusionError::Internal("AggregateExpressionWithFilter expression was not an AggregateFunction".to_string())) } - } SQLExpr::Function(mut function) => { @@ -2221,7 +2222,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun, distinct, args, - filter: None + filter: None, }); }; @@ -2247,9 +2248,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(*e, schema, ctes), - SQLExpr::Exists{ subquery, negated } => self.parse_exists_subquery(&subquery, negated, schema, ctes), + SQLExpr::Exists { subquery, negated } => self.parse_exists_subquery(&subquery, negated, schema, ctes), - SQLExpr::InSubquery { expr, subquery, negated } => self.parse_in_subquery(&expr, &subquery, negated, schema, ctes), + SQLExpr::InSubquery { expr, subquery, negated } => self.parse_in_subquery(&expr, &subquery, negated, schema, ctes), SQLExpr::Subquery(subquery) => self.parse_scalar_subquery(&subquery, schema, ctes), @@ -2390,7 +2391,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return Err(DataFusionError::NotImplemented(format!( "Unsupported interval argument. Expected string literal, got: {:?}", value - ))) + ))); } }; diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index d138897a84ed..a4176474c72b 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -21,6 +21,7 @@ use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE use sqlparser::ast::Ident; use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::expr::Cast; use datafusion_expr::expr::{ Between, BinaryExpr, Case, GetIndexedField, GroupingSet, Like, }; @@ -340,13 +341,10 @@ where Expr::IsNotUnknown(nested_expr) => Ok(Expr::IsNotUnknown(Box::new( clone_with_replacement(nested_expr, replacement_fn)?, ))), - Expr::Cast { - expr: nested_expr, - data_type, - } => Ok(Expr::Cast { - expr: Box::new(clone_with_replacement(nested_expr, replacement_fn)?), - data_type: data_type.clone(), - }), + Expr::Cast(Cast { expr, data_type }) => Ok(Expr::Cast(Cast::new( + Box::new(clone_with_replacement(expr, replacement_fn)?), + data_type.clone(), + ))), Expr::TryCast { expr: nested_expr, data_type,