diff --git a/datafusion/src/physical_plan/coercion_rule/binary_rule.rs b/datafusion/src/physical_plan/coercion_rule/binary_rule.rs index 9fc29d6ee28c..205ee2e23ca9 100644 --- a/datafusion/src/physical_plan/coercion_rule/binary_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/binary_rule.rs @@ -21,9 +21,10 @@ use crate::arrow::datatypes::DataType; use crate::error::{DataFusionError, Result}; use crate::logical_plan::Operator; use crate::physical_plan::expressions::coercion::{ - dictionary_coercion, eq_coercion, is_numeric, like_coercion, numerical_coercion, - string_coercion, temporal_coercion, + dictionary_coercion, eq_coercion, is_numeric, like_coercion, string_coercion, + temporal_coercion, }; +use crate::scalar::{MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128}; /// Coercion rules for all binary operators. Returns the output type /// of applying `op` to an argument of `lhs_type` and `rhs_type`. @@ -49,12 +50,11 @@ pub(crate) fn coerce_types( Operator::Like | Operator::NotLike => like_coercion(lhs_type, rhs_type), // for math expressions, the final value of the coercion is also the return type // because coercion favours higher information types - // TODO: support decimal data type Operator::Plus | Operator::Minus | Operator::Modulo | Operator::Divide - | Operator::Multiply => numerical_coercion(lhs_type, rhs_type), + | Operator::Multiply => mathematics_numerical_coercion(op, lhs_type, rhs_type), Operator::RegexMatch | Operator::RegexIMatch | Operator::RegexNotMatch @@ -162,12 +162,141 @@ fn get_comparison_common_decimal_type( } } +// Convert the numeric data type to the decimal data type. +// Now, we just support the signed integer type and floating-point type. +fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { + match numeric_type { + DataType::Int8 => Some(DataType::Decimal(3, 0)), + DataType::Int16 => Some(DataType::Decimal(5, 0)), + DataType::Int32 => Some(DataType::Decimal(10, 0)), + DataType::Int64 => Some(DataType::Decimal(20, 0)), + // TODO if we convert the floating-point data to the decimal type, it maybe overflow. + DataType::Float32 => Some(DataType::Decimal(14, 7)), + DataType::Float64 => Some(DataType::Decimal(30, 15)), + _ => None, + } +} + +fn mathematics_numerical_coercion( + mathematics_op: &Operator, + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + + // error on any non-numeric type + if !is_numeric(lhs_type) || !is_numeric(rhs_type) { + return None; + }; + + // same type => all good + if lhs_type == rhs_type { + return Some(lhs_type.clone()); + } + + // these are ordered from most informative to least informative so + // that the coercion removes the least amount of information + match (lhs_type, rhs_type) { + (Decimal(_, _), Decimal(_, _)) => { + coercion_decimal_mathematics_type(mathematics_op, lhs_type, rhs_type) + } + (Decimal(_, _), _) => { + let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type); + match converted_decimal_type { + None => None, + Some(right_decimal_type) => coercion_decimal_mathematics_type( + mathematics_op, + lhs_type, + &right_decimal_type, + ), + } + } + (_, Decimal(_, _)) => { + let converted_decimal_type = coerce_numeric_type_to_decimal(lhs_type); + match converted_decimal_type { + None => None, + Some(left_decimal_type) => coercion_decimal_mathematics_type( + mathematics_op, + &left_decimal_type, + rhs_type, + ), + } + } + (Float64, _) | (_, Float64) => Some(Float64), + (_, Float32) | (Float32, _) => Some(Float32), + (Int64, _) | (_, Int64) => Some(Int64), + (Int32, _) | (_, Int32) => Some(Int32), + (Int16, _) | (_, Int16) => Some(Int16), + (Int8, _) | (_, Int8) => Some(Int8), + (UInt64, _) | (_, UInt64) => Some(UInt64), + (UInt32, _) | (_, UInt32) => Some(UInt32), + (UInt16, _) | (_, UInt16) => Some(UInt16), + (UInt8, _) | (_, UInt8) => Some(UInt8), + _ => None, + } +} + +fn create_decimal_type(precision: usize, scale: usize) -> DataType { + DataType::Decimal( + MAX_PRECISION_FOR_DECIMAL128.min(precision), + MAX_SCALE_FOR_DECIMAL128.min(scale), + ) +} + +fn coercion_decimal_mathematics_type( + mathematics_op: &Operator, + left_decimal_type: &DataType, + right_decimal_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + match (left_decimal_type, right_decimal_type) { + // The coercion rule from spark + // https://github.com/apache/spark/blob/c20af535803a7250fef047c2bf0fe30be242369d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala#L35 + (Decimal(p1, s1), Decimal(p2, s2)) => { + match mathematics_op { + Operator::Plus | Operator::Minus => { + // max(s1, s2) + let result_scale = *s1.max(s2); + // max(s1, s2) + max(p1-s1, p2-s2) + 1 + let result_precision = result_scale + (*p1 - *s1).max(*p2 - *s2) + 1; + Some(create_decimal_type(result_precision, result_scale)) + } + Operator::Multiply => { + // s1 + s2 + let result_scale = *s1 + *s2; + // p1 + p2 + 1 + let result_precision = *p1 + *p2 + 1; + Some(create_decimal_type(result_precision, result_scale)) + } + Operator::Divide => { + // max(6, s1 + p2 + 1) + let result_scale = 6.max(*s1 + *p2 + 1); + // p1 - s1 + s2 + max(6, s1 + p2 + 1) + let result_precision = result_scale + *p1 - *s1 + *s2; + Some(create_decimal_type(result_precision, result_scale)) + } + Operator::Modulo => { + // max(s1, s2) + let result_scale = *s1.max(s2); + // min(p1-s1, p2-s2) + max(s1, s2) + let result_precision = result_scale + (*p1 - *s1).min(*p2 - *s2); + Some(create_decimal_type(result_precision, result_scale)) + } + _ => unreachable!(), + } + } + _ => unreachable!(), + } +} + #[cfg(test)] mod tests { use crate::arrow::datatypes::DataType; use crate::error::{DataFusionError, Result}; use crate::logical_plan::Operator; - use crate::physical_plan::coercion_rule::binary_rule::coerce_types; + use crate::physical_plan::coercion_rule::binary_rule::{ + coerce_numeric_type_to_decimal, coerce_types, coercion_decimal_mathematics_type, + }; #[test] @@ -226,4 +355,70 @@ mod tests { assert!(result_type.is_err()); Ok(()) } + + #[test] + fn test_decimal_mathematics_op_type() { + assert_eq!( + coerce_numeric_type_to_decimal(&DataType::Int8).unwrap(), + DataType::Decimal(3, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal(&DataType::Int16).unwrap(), + DataType::Decimal(5, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal(&DataType::Int32).unwrap(), + DataType::Decimal(10, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal(&DataType::Int64).unwrap(), + DataType::Decimal(20, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal(&DataType::Float32).unwrap(), + DataType::Decimal(14, 7) + ); + assert_eq!( + coerce_numeric_type_to_decimal(&DataType::Float64).unwrap(), + DataType::Decimal(30, 15) + ); + + let op = Operator::Plus; + let left_decimal_type = DataType::Decimal(10, 3); + let right_decimal_type = DataType::Decimal(20, 4); + let result = coercion_decimal_mathematics_type( + &op, + &left_decimal_type, + &right_decimal_type, + ); + assert_eq!(DataType::Decimal(21, 4), result.unwrap()); + let op = Operator::Minus; + let result = coercion_decimal_mathematics_type( + &op, + &left_decimal_type, + &right_decimal_type, + ); + assert_eq!(DataType::Decimal(21, 4), result.unwrap()); + let op = Operator::Multiply; + let result = coercion_decimal_mathematics_type( + &op, + &left_decimal_type, + &right_decimal_type, + ); + assert_eq!(DataType::Decimal(31, 7), result.unwrap()); + let op = Operator::Divide; + let result = coercion_decimal_mathematics_type( + &op, + &left_decimal_type, + &right_decimal_type, + ); + assert_eq!(DataType::Decimal(35, 24), result.unwrap()); + let op = Operator::Modulo; + let result = coercion_decimal_mathematics_type( + &op, + &left_decimal_type, + &right_decimal_type, + ); + assert_eq!(DataType::Decimal(11, 4), result.unwrap()); + } } diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs index 41715fe62a74..4a14c4c47b4d 100644 --- a/datafusion/src/physical_plan/expressions/binary.rs +++ b/datafusion/src/physical_plan/expressions/binary.rs @@ -41,6 +41,7 @@ use arrow::compute::kernels::comparison::{ regexp_is_match_utf8_scalar, }; use arrow::datatypes::{ArrowNumericType, DataType, Schema, TimeUnit}; +use arrow::error::ArrowError::DivideByZero; use arrow::record_batch::RecordBatch; use crate::error::{DataFusionError, Result}; @@ -235,12 +236,10 @@ fn is_distinct_from_decimal( ) -> Result { let mut bool_builder = BooleanBuilder::new(left.len()); for i in 0..left.len() { - if left.is_null(i) && right.is_null(i) { - bool_builder.append_value(false)?; - } else if left.is_null(i) || right.is_null(i) { - bool_builder.append_value(true)?; - } else { - bool_builder.append_value(left.value(i) != right.value(i))?; + match (left.is_null(i), right.is_null(i)) { + (true, true) => bool_builder.append_value(false)?, + (true, false) | (false, true) => bool_builder.append_value(true)?, + (_, _) => bool_builder.append_value(left.value(i) != right.value(i))?, } } Ok(bool_builder.finish()) @@ -252,17 +251,89 @@ fn is_not_distinct_from_decimal( ) -> Result { let mut bool_builder = BooleanBuilder::new(left.len()); for i in 0..left.len() { - if left.is_null(i) && right.is_null(i) { - bool_builder.append_value(true)?; - } else if left.is_null(i) || right.is_null(i) { - bool_builder.append_value(false)?; - } else { - bool_builder.append_value(left.value(i) == right.value(i))?; + match (left.is_null(i), right.is_null(i)) { + (true, true) => bool_builder.append_value(true)?, + (true, false) | (false, true) => bool_builder.append_value(false)?, + (_, _) => bool_builder.append_value(left.value(i) == right.value(i))?, } } Ok(bool_builder.finish()) } +fn add_decimal(left: &DecimalArray, right: &DecimalArray) -> Result { + let mut decimal_builder = + DecimalBuilder::new(left.len(), left.precision(), left.scale()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(left.value(i) + right.value(i))?; + } + } + Ok(decimal_builder.finish()) +} + +fn subtract_decimal(left: &DecimalArray, right: &DecimalArray) -> Result { + let mut decimal_builder = + DecimalBuilder::new(left.len(), left.precision(), left.scale()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(left.value(i) - right.value(i))?; + } + } + Ok(decimal_builder.finish()) +} + +fn multiply_decimal(left: &DecimalArray, right: &DecimalArray) -> Result { + let mut decimal_builder = + DecimalBuilder::new(left.len(), left.precision(), left.scale()); + let divide = 10_i128.pow(left.scale() as u32); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(left.value(i) * right.value(i) / divide)?; + } + } + Ok(decimal_builder.finish()) +} + +fn divide_decimal(left: &DecimalArray, right: &DecimalArray) -> Result { + let mut decimal_builder = + DecimalBuilder::new(left.len(), left.precision(), left.scale()); + let mul = 10_f64.powi(left.scale() as i32); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.append_null()?; + } else if right.value(i) == 0 { + return Err(DataFusionError::ArrowError(DivideByZero)); + } else { + let l_value = left.value(i) as f64; + let r_value = right.value(i) as f64; + let result = ((l_value / r_value) * mul) as i128; + decimal_builder.append_value(result)?; + } + } + Ok(decimal_builder.finish()) +} + +fn modulus_decimal(left: &DecimalArray, right: &DecimalArray) -> Result { + let mut decimal_builder = + DecimalBuilder::new(left.len(), left.precision(), left.scale()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.append_null()?; + } else if right.value(i) == 0 { + return Err(DataFusionError::ArrowError(DivideByZero)); + } else { + decimal_builder.append_value(left.value(i) % right.value(i))?; + } + } + Ok(decimal_builder.finish()) +} + /// Binary expression #[derive(Debug)] pub struct BinaryExpr { @@ -472,6 +543,9 @@ macro_rules! binary_string_array_op { macro_rules! binary_primitive_array_op { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ match $LEFT.data_type() { + // TODO support decimal type + // which is not the primitive type + DataType::Decimal(_,_) => compute_decimal_op!($LEFT, $RIGHT, $OP, DecimalArray), DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array), DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array), DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array), @@ -2549,6 +2623,7 @@ mod tests { .unwrap(); // is distinct: float64array is distinct decimal array // TODO: now we do not refactor the `is distinct or is not distinct` rule of coercion. + // traced by https://github.com/apache/arrow-datafusion/issues/1590 // the decimal array will be casted to float64array apply_logic_op( &schema, @@ -2570,4 +2645,197 @@ mod tests { Ok(()) } + + #[test] + fn arithmetic_decimal_op_test() -> Result<()> { + let value_i128: i128 = 123; + let left_decimal_array = create_decimal_array( + &[ + Some(value_i128), + None, + Some(value_i128 - 1), + Some(value_i128 + 1), + ], + 25, + 3, + )?; + let right_decimal_array = create_decimal_array( + &[ + Some(value_i128), + Some(value_i128), + Some(value_i128), + Some(value_i128), + ], + 25, + 3, + )?; + // add + let result = add_decimal(&left_decimal_array, &right_decimal_array)?; + let expect = + create_decimal_array(&[Some(246), None, Some(245), Some(247)], 25, 3)?; + assert_eq!(expect, result); + // subtract + let result = subtract_decimal(&left_decimal_array, &right_decimal_array)?; + let expect = create_decimal_array(&[Some(0), None, Some(-1), Some(1)], 25, 3)?; + assert_eq!(expect, result); + // multiply + let result = multiply_decimal(&left_decimal_array, &right_decimal_array)?; + let expect = create_decimal_array(&[Some(15), None, Some(15), Some(15)], 25, 3)?; + assert_eq!(expect, result); + // divide + let left_decimal_array = create_decimal_array( + &[Some(1234567), None, Some(1234567), Some(1234567)], + 25, + 3, + )?; + let right_decimal_array = + create_decimal_array(&[Some(10), Some(100), Some(55), Some(-123)], 25, 3)?; + let result = divide_decimal(&left_decimal_array, &right_decimal_array)?; + let expect = create_decimal_array( + &[Some(123456700), None, Some(22446672), Some(-10037130)], + 25, + 3, + )?; + assert_eq!(expect, result); + // modulus + let result = modulus_decimal(&left_decimal_array, &right_decimal_array)?; + let expect = create_decimal_array(&[Some(7), None, Some(37), Some(16)], 25, 3)?; + assert_eq!(expect, result); + + Ok(()) + } + + fn apply_arithmetic_op( + schema: &SchemaRef, + left: &ArrayRef, + right: &ArrayRef, + op: Operator, + expected: ArrayRef, + ) -> Result<()> { + let arithmetic_op = + binary_simple(col("a", schema)?, op, col("b", schema)?, schema); + let data: Vec = vec![left.clone(), right.clone()]; + let batch = RecordBatch::try_new(schema.clone(), data)?; + let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + + assert_eq!(result.as_ref(), expected.as_ref()); + Ok(()) + } + + #[test] + fn arithmetic_decimal_expr_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Decimal(10, 2), true), + ])); + let value: i128 = 123; + let decimal_array = Arc::new(create_decimal_array( + &[ + Some(value as i128), // 1.23 + None, + Some((value - 1) as i128), // 1.22 + Some((value + 1) as i128), // 1.24 + ], + 10, + 2, + )?) as ArrayRef; + let int32_array = Arc::new(Int32Array::from(vec![ + Some(123), + Some(122), + Some(123), + Some(124), + ])) as ArrayRef; + + // add: Int32array add decimal array + let expect = Arc::new(create_decimal_array( + &[Some(12423), None, Some(12422), Some(12524)], + 13, + 2, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Plus, + expect, + ) + .unwrap(); + + // subtract: decimal array subtract int32 array + let schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int32, true), + Field::new("a", DataType::Decimal(10, 2), true), + ])); + let expect = Arc::new(create_decimal_array( + &[Some(-12177), None, Some(-12178), Some(-12276)], + 13, + 2, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Minus, + expect, + ) + .unwrap(); + + // multiply: decimal array multiply int32 array + let expect = Arc::new(create_decimal_array( + &[Some(15129), None, Some(15006), Some(15376)], + 21, + 2, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Multiply, + expect, + ) + .unwrap(); + // divide: int32 array divide decimal array + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Decimal(10, 2), true), + ])); + let expect = Arc::new(create_decimal_array( + &[ + Some(10000000000000), + None, + Some(10081967213114), + Some(10000000000000), + ], + 23, + 11, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Divide, + expect, + ) + .unwrap(); + // modulus: int32 array modulus decimal array + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Decimal(10, 2), true), + ])); + let expect = Arc::new(create_decimal_array( + &[Some(000), None, Some(100), Some(000)], + 10, + 2, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Modulo, + expect, + ) + .unwrap(); + + Ok(()) + } }