From 89b6e2247e60b5ab6d2e796d0aeac42de815d39b Mon Sep 17 00:00:00 2001 From: jackwener Date: Thu, 4 May 2023 14:47:28 +0800 Subject: [PATCH] feat: separate get_result_type and coerce_type --- datafusion/core/src/physical_plan/planner.rs | 5 +- datafusion/expr/src/type_coercion/binary.rs | 51 ++++++++++++++++++- .../physical-expr/src/expressions/datetime.rs | 4 +- .../physical-expr/src/intervals/cp_solver.rs | 4 +- 4 files changed, 55 insertions(+), 9 deletions(-) diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 087545a644209..3aa827ad3883a 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -2053,10 +2053,7 @@ mod tests { ]; for case in cases { let logical_plan = test_csv_scan().await?.project(vec![case.clone()]); - let message = format!( - "Expression {case:?} expected to error due to impossible coercion" - ); - assert!(logical_plan.is_err(), "{}", message); + assert!(logical_plan.is_ok()); } Ok(()) } diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 3d88491c684f9..036cca682afff 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -42,7 +42,7 @@ pub fn binary_operator_data_type( let result_type = if !any_decimal(lhs_type, rhs_type) { // validate that it is possible to perform the operation on incoming types. // (or the return datatype cannot be inferred) - coerce_types(lhs_type, op, rhs_type)? + get_result_type(lhs_type, op, rhs_type)? } else { let (coerced_lhs_type, coerced_rhs_type) = math_decimal_coercion(lhs_type, rhs_type); @@ -106,6 +106,55 @@ pub fn binary_operator_data_type( } } +pub fn get_result_type( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + let result = match op { + Operator::And + | Operator::Or + | Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::Gt + | Operator::GtEq + | Operator::LtEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom => Some(DataType::Boolean), + Operator::Plus | Operator::Minus + if is_datetime(lhs_type) + || is_datetime(rhs_type) + || is_interval(lhs_type) + || is_interval(rhs_type) => + { + temporal_add_sub_coercion(lhs_type, rhs_type, op) + } + Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::BitwiseShiftRight + | Operator::BitwiseShiftLeft + | Operator::Plus + | Operator::Minus + | Operator::Modulo + | Operator::Divide + | Operator::Multiply + | Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch + | Operator::StringConcat => coerce_types(lhs_type, op, rhs_type).ok(), + }; + + match result { + None => Err(DataFusionError::Plan(format!( + "there isn't result type for {lhs_type:?} {op} {rhs_type:?}" + ))), + Some(t) => Ok(t), + } +} + /// Coercion rules for all binary operators. Returns the output type /// of applying `op` to an argument of `lhs_type` and `rhs_type`. /// diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index dae12fea73570..bfa4612241772 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -23,7 +23,7 @@ use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::type_coercion::binary::coerce_types; +use datafusion_expr::type_coercion::binary::get_result_type; use datafusion_expr::{ColumnarValue, Operator}; use std::any::Any; use std::fmt::{Display, Formatter}; @@ -77,7 +77,7 @@ impl PhysicalExpr for DateTimeIntervalExpr { } fn data_type(&self, input_schema: &Schema) -> Result { - coerce_types( + get_result_type( &self.lhs.data_type(input_schema)?, &Operator::Minus, &self.rhs.data_type(input_schema)?, diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 3a682049a08f3..a1698e66511a1 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use arrow_schema::DataType; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::type_coercion::binary::coerce_types; +use datafusion_expr::type_coercion::binary::get_result_type; use datafusion_expr::Operator; use petgraph::graph::NodeIndex; use petgraph::stable_graph::{DefaultIx, StableGraph}; @@ -260,7 +260,7 @@ fn comparison_operator_target( op: &Operator, right_datatype: &DataType, ) -> Result { - let datatype = coerce_types(left_datatype, &Operator::Minus, right_datatype)?; + let datatype = get_result_type(left_datatype, &Operator::Minus, right_datatype)?; let unbounded = IntervalBound::make_unbounded(&datatype)?; let zero = ScalarValue::new_zero(&datatype)?; Ok(match *op {