From 2e9beeba01b85afb6d4f6557201e673008ea9edd Mon Sep 17 00:00:00 2001 From: jakevin Date: Tue, 9 May 2023 00:04:33 +0800 Subject: [PATCH] minor: remove prefix in type_coercion (#6283) --- datafusion/expr/src/type_coercion/binary.rs | 97 +++++++++---------- .../optimizer/src/analyzer/type_coercion.rs | 5 +- .../physical-expr/src/expressions/binary.rs | 12 +-- 3 files changed, 53 insertions(+), 61 deletions(-) diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index f01b06e0db06..242d609cbd9c 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -316,10 +316,10 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (Utf8, _) if DataType::is_numeric(rhs_type) => Some(Utf8), - (LargeUtf8, _) if DataType::is_numeric(rhs_type) => Some(LargeUtf8), - (_, Utf8) if DataType::is_numeric(lhs_type) => Some(Utf8), - (_, LargeUtf8) if DataType::is_numeric(lhs_type) => Some(LargeUtf8), + (Utf8, _) if is_numeric(rhs_type) => Some(Utf8), + (LargeUtf8, _) if is_numeric(rhs_type) => Some(LargeUtf8), + (_, Utf8) if is_numeric(lhs_type) => Some(Utf8), + (_, LargeUtf8) if is_numeric(lhs_type) => Some(LargeUtf8), _ => None, } } @@ -344,7 +344,9 @@ fn comparison_binary_numeric_coercion( // that the coercion does not lose information via truncation match (lhs_type, rhs_type) { // support decimal data type for comparison operation - (d1 @ Decimal128(_, _), d2 @ Decimal128(_, _)) => get_wider_decimal_type(d1, d2), + (Decimal128(_, _), Decimal128(_, _)) => { + get_wider_decimal_type(lhs_type, rhs_type) + } (Decimal128(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type), (_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type), (Float64, _) | (_, Float64) => Some(Float64), @@ -390,23 +392,22 @@ fn get_comparison_common_decimal_type( decimal_type: &DataType, other_type: &DataType, ) -> Option { + use arrow::datatypes::DataType::*; let other_decimal_type = &match other_type { // This conversion rule is from spark // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 - DataType::Int8 => DataType::Decimal128(3, 0), - DataType::Int16 => DataType::Decimal128(5, 0), - DataType::Int32 => DataType::Decimal128(10, 0), - DataType::Int64 => DataType::Decimal128(20, 0), - DataType::Float32 => DataType::Decimal128(14, 7), - DataType::Float64 => DataType::Decimal128(30, 15), + Int8 => Decimal128(3, 0), + Int16 => Decimal128(5, 0), + Int32 => Decimal128(10, 0), + Int64 => Decimal128(20, 0), + Float32 => Decimal128(14, 7), + Float64 => Decimal128(30, 15), _ => { return None; } }; match (decimal_type, &other_decimal_type) { - (d1 @ DataType::Decimal128(_, _), d2 @ DataType::Decimal128(_, _)) => { - get_wider_decimal_type(d1, d2) - } + (d1 @ Decimal128(_, _), d2 @ Decimal128(_, _)) => get_wider_decimal_type(d1, d2), _ => None, } } @@ -433,14 +434,15 @@ fn get_wider_decimal_type( /// Convert the numeric data type to the decimal data type. /// Now, we just support the signed integer type and floating-point type. fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; match numeric_type { - DataType::Int8 => Some(DataType::Decimal128(3, 0)), - DataType::Int16 => Some(DataType::Decimal128(5, 0)), - DataType::Int32 => Some(DataType::Decimal128(10, 0)), - DataType::Int64 => Some(DataType::Decimal128(20, 0)), + Int8 => Some(Decimal128(3, 0)), + Int16 => Some(Decimal128(5, 0)), + Int32 => Some(Decimal128(10, 0)), + Int64 => Some(Decimal128(20, 0)), // TODO if we convert the floating-point data to the decimal type, it maybe overflow. - DataType::Float32 => Some(DataType::Decimal128(14, 7)), - DataType::Float64 => Some(DataType::Decimal128(30, 15)), + Float32 => Some(Decimal128(14, 7)), + Float64 => Some(Decimal128(30, 15)), _ => None, } } @@ -605,48 +607,38 @@ pub fn decimal_op_mathematics_type( /// Determine if at least of one of lhs and rhs is numeric, and the other must be NULL or numeric fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool { + use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (_, DataType::Null) => is_numeric(lhs_type), - (DataType::Null, _) => is_numeric(rhs_type), - ( - DataType::Dictionary(_, lhs_value_type), - DataType::Dictionary(_, rhs_value_type), - ) => is_numeric(lhs_value_type) && is_numeric(rhs_value_type), - (DataType::Dictionary(_, value_type), _) => { - is_numeric(value_type) && is_numeric(rhs_type) - } - (_, DataType::Dictionary(_, value_type)) => { - is_numeric(lhs_type) && is_numeric(value_type) + (_, Null) => is_numeric(lhs_type), + (Null, _) => is_numeric(rhs_type), + (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => { + is_numeric(lhs_value_type) && is_numeric(rhs_value_type) } + (Dictionary(_, value_type), _) => is_numeric(value_type) && is_numeric(rhs_type), + (_, Dictionary(_, value_type)) => is_numeric(lhs_type) && is_numeric(value_type), _ => is_numeric(lhs_type) && is_numeric(rhs_type), } } /// Determine if at least of one of lhs and rhs is decimal, and the other must be NULL or decimal fn both_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool { + use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (_, DataType::Null) => is_decimal(lhs_type), - (DataType::Null, _) => is_decimal(rhs_type), - (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => true, - (DataType::Dictionary(_, value_type), _) => { - is_decimal(value_type) && is_decimal(rhs_type) - } - (_, DataType::Dictionary(_, value_type)) => { - is_decimal(lhs_type) && is_decimal(value_type) - } + (_, Null) => is_decimal(lhs_type), + (Null, _) => is_decimal(rhs_type), + (Decimal128(_, _), Decimal128(_, _)) => true, + (Dictionary(_, value_type), _) => is_decimal(value_type) && is_decimal(rhs_type), + (_, Dictionary(_, value_type)) => is_decimal(lhs_type) && is_decimal(value_type), _ => false, } } /// Determine if at least of one of lhs and rhs is decimal pub fn any_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool { + use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (DataType::Dictionary(_, value_type), _) => { - is_decimal(value_type) || is_decimal(rhs_type) - } - (_, DataType::Dictionary(_, value_type)) => { - is_decimal(lhs_type) || is_decimal(value_type) - } + (Dictionary(_, value_type), _) => is_decimal(value_type) || is_decimal(rhs_type), + (_, Dictionary(_, value_type)) => is_decimal(lhs_type) || is_decimal(value_type), (_, _) => is_decimal(lhs_type) || is_decimal(rhs_type), } } @@ -661,21 +653,22 @@ fn dictionary_coercion( rhs_type: &DataType, preserve_dictionaries: bool, ) -> Option { + use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { ( - DataType::Dictionary(_lhs_index_type, lhs_value_type), - DataType::Dictionary(_rhs_index_type, rhs_value_type), + Dictionary(_lhs_index_type, lhs_value_type), + Dictionary(_rhs_index_type, rhs_value_type), ) => comparison_coercion(lhs_value_type, rhs_value_type), - (d @ DataType::Dictionary(_, value_type), other_type) - | (other_type, d @ DataType::Dictionary(_, value_type)) + (d @ Dictionary(_, value_type), other_type) + | (other_type, d @ Dictionary(_, value_type)) if preserve_dictionaries && value_type.as_ref() == other_type => { Some(d.clone()) } - (DataType::Dictionary(_index_type, value_type), _) => { + (Dictionary(_index_type, value_type), _) => { comparison_coercion(value_type, rhs_type) } - (_, DataType::Dictionary(_index_type, value_type)) => { + (_, Dictionary(_index_type, value_type)) => { comparison_coercion(lhs_type, value_type) } _ => None, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index a7e45c9fd231..759275dccc49 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -562,9 +562,8 @@ fn coerce_window_frame( // The above op will be rewrite to the binary op when creating the physical op. fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result { let left_type = expr.get_type(schema)?; - let right_type = DataType::Boolean; - let coerced_type = coerce_types(&left_type, &Operator::IsDistinctFrom, &right_type)?; - expr.clone().cast_to(&coerced_type, schema) + coerce_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?; + expr.clone().cast_to(&DataType::Boolean, schema) } /// Returns `expressions` coerced to types compatible with diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index d99c147202c0..9b46d792589c 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -2208,10 +2208,10 @@ mod tests { ]); let a = $A_ARRAY::from($A_VEC); let b = $B_ARRAY::from($B_VEC); - let result_type = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?; + let common_type = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?; - let left = try_cast(col("a", &schema)?, &schema, result_type.clone())?; - let right = try_cast(col("b", &schema)?, &schema, result_type)?; + let left = try_cast(col("a", &schema)?, &schema, common_type.clone())?; + let right = try_cast(col("b", &schema)?, &schema, common_type)?; // verify that we can construct the expression let expression = binary(left, $OP, right, &schema)?; @@ -3687,10 +3687,10 @@ mod tests { ) -> Result<()> { let left_type = left.data_type(); let right_type = right.data_type(); - let result_type = coerce_types(left_type, &op, right_type)?; + let common_type = coerce_types(left_type, &op, right_type)?; - let left_expr = try_cast(col("a", schema)?, schema, result_type.clone())?; - let right_expr = try_cast(col("b", schema)?, schema, result_type)?; + let left_expr = try_cast(col("a", schema)?, schema, common_type.clone())?; + let right_expr = try_cast(col("b", schema)?, schema, common_type)?; let arithmetic_op = binary_simple(left_expr, op, right_expr, schema); let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?;