diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 04ae32ec35aa..4356f36b18d8 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -34,6 +34,7 @@ use arrow::{ }, datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}, }; +use arrow_array::Decimal256Array; // Downcast ArrayRef to Date32Array pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> { @@ -65,6 +66,11 @@ pub fn as_decimal128_array(array: &dyn Array) -> Result<&Decimal128Array> { Ok(downcast_value!(array, Decimal128Array)) } +// Downcast ArrayRef to Decimal256Array +pub fn as_decimal256_array(array: &dyn Array) -> Result<&Decimal256Array> { + Ok(downcast_value!(array, Decimal256Array)) +} + // Downcast ArrayRef to Float32Array pub fn as_float32_array(array: &dyn Array) -> Result<&Float32Array> { Ok(downcast_value!(array, Float32Array)) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 99ff5f3384d4..4a7767023fed 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -26,14 +26,14 @@ use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; use crate::cast::{ - as_decimal128_array, as_dictionary_array, as_fixed_size_binary_array, - as_fixed_size_list_array, as_list_array, as_struct_array, + as_decimal128_array, as_decimal256_array, as_dictionary_array, + as_fixed_size_binary_array, as_fixed_size_list_array, as_list_array, as_struct_array, }; use crate::delta::shift_months; use crate::error::{DataFusionError, Result}; use arrow::buffer::NullBuffer; use arrow::compute::nullif; -use arrow::datatypes::{FieldRef, Fields, SchemaBuilder}; +use arrow::datatypes::{i256, FieldRef, Fields, SchemaBuilder}; use arrow::{ array::*, compute::kernels::cast::{cast_with_options, CastOptions}, @@ -47,6 +47,7 @@ use arrow::{ }, }; use arrow_array::timezone::Tz; +use arrow_array::ArrowNativeTypeOp; use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime}; // Constants 
we use throughout this file: @@ -75,6 +76,8 @@ pub enum ScalarValue { Float64(Option), /// 128bit decimal, using the i128 to represent the decimal, precision scale Decimal128(Option, u8, i8), + /// 256bit decimal, using the i256 to represent the decimal, precision scale + Decimal256(Option, u8, i8), /// signed 8bit int Int8(Option), /// signed 16bit int @@ -160,6 +163,10 @@ impl PartialEq for ScalarValue { v1.eq(v2) && p1.eq(p2) && s1.eq(s2) } (Decimal128(_, _, _), _) => false, + (Decimal256(v1, p1, s1), Decimal256(v2, p2, s2)) => { + v1.eq(v2) && p1.eq(p2) && s1.eq(s2) + } + (Decimal256(_, _, _), _) => false, (Boolean(v1), Boolean(v2)) => v1.eq(v2), (Boolean(_), _) => false, (Float32(v1), Float32(v2)) => match (v1, v2) { @@ -283,6 +290,15 @@ impl PartialOrd for ScalarValue { } } (Decimal128(_, _, _), _) => None, + (Decimal256(v1, p1, s1), Decimal256(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + v1.partial_cmp(v2) + } else { + // Two decimal values can be compared if they have the same precision and scale. + None + } + } + (Decimal256(_, _, _), _) => None, (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2), (Boolean(_), _) => None, (Float32(v1), Float32(v2)) => match (v1, v2) { @@ -1038,6 +1054,7 @@ macro_rules! 
impl_op_arithmetic { get_sign!($OPERATION), true, )))), + // todo: Add Decimal256 support _ => Err(DataFusionError::Internal(format!( "Operator {} is not implemented for types {:?} and {:?}", stringify!($OPERATION), @@ -1516,6 +1533,11 @@ impl std::hash::Hash for ScalarValue { p.hash(state); s.hash(state) } + Decimal256(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state) + } Boolean(v) => v.hash(state), Float32(v) => v.map(Fl).hash(state), Float64(v) => v.map(Fl).hash(state), @@ -1994,6 +2016,9 @@ impl ScalarValue { ScalarValue::Decimal128(_, precision, scale) => { DataType::Decimal128(*precision, *scale) } + ScalarValue::Decimal256(_, precision, scale) => { + DataType::Decimal256(*precision, *scale) + } ScalarValue::TimestampSecond(_, tz_opt) => { DataType::Timestamp(TimeUnit::Second, tz_opt.clone()) } @@ -2083,6 +2108,9 @@ impl ScalarValue { ScalarValue::Decimal128(Some(v), precision, scale) => { Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale)) } + ScalarValue::Decimal256(Some(v), precision, scale) => Ok( + ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale), + ), value => Err(DataFusionError::Internal(format!( "Can not run arithmetic negative on scalar value {value:?}" ))), @@ -2154,6 +2182,7 @@ impl ScalarValue { ScalarValue::Float32(v) => v.is_none(), ScalarValue::Float64(v) => v.is_none(), ScalarValue::Decimal128(v, _, _) => v.is_none(), + ScalarValue::Decimal256(v, _, _) => v.is_none(), ScalarValue::Int8(v) => v.is_none(), ScalarValue::Int16(v) => v.is_none(), ScalarValue::Int32(v) => v.is_none(), @@ -2415,10 +2444,10 @@ impl ScalarValue { ScalarValue::iter_to_decimal_array(scalars, *precision, *scale)?; Arc::new(decimal_array) } - DataType::Decimal256(_, _) => { - return Err(DataFusionError::Internal( - "Decimal256 is not supported for ScalarValue".to_string(), - )); + DataType::Decimal256(precision, scale) => { + let decimal_array = + ScalarValue::iter_to_decimal256_array(scalars, *precision, *scale)?; + 
Arc::new(decimal_array) } DataType::Null => ScalarValue::iter_to_null_array(scalars), DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), @@ -2680,6 +2709,22 @@ impl ScalarValue { Ok(array) } + fn iter_to_decimal256_array( + scalars: impl IntoIterator, + precision: u8, + scale: i8, + ) -> Result { + let array = scalars + .into_iter() + .map(|element: ScalarValue| match element { + ScalarValue::Decimal256(v1, _, _) => v1, + _ => unreachable!(), + }) + .collect::() + .with_precision_and_scale(precision, scale)?; + Ok(array) + } + fn iter_to_array_list( scalars: impl IntoIterator, data_type: &DataType, @@ -2764,12 +2809,28 @@ impl ScalarValue { } } + fn build_decimal256_array( + value: Option, + precision: u8, + scale: i8, + size: usize, + ) -> Decimal256Array { + std::iter::repeat(value) + .take(size) + .collect::() + .with_precision_and_scale(precision, scale) + .unwrap() + } + /// Converts a scalar value into an array of `size` rows. pub fn to_array_of_size(&self, size: usize) -> ArrayRef { match self { ScalarValue::Decimal128(e, precision, scale) => Arc::new( ScalarValue::build_decimal_array(*e, *precision, *scale, size), ), + ScalarValue::Decimal256(e, precision, scale) => Arc::new( + ScalarValue::build_decimal256_array(*e, *precision, *scale, size), + ), ScalarValue::Boolean(e) => { Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef } @@ -3044,12 +3105,28 @@ impl ScalarValue { precision: u8, scale: i8, ) -> Result { - let array = as_decimal128_array(array)?; - if array.is_null(index) { - Ok(ScalarValue::Decimal128(None, precision, scale)) - } else { - let value = array.value(index); - Ok(ScalarValue::Decimal128(Some(value), precision, scale)) + match array.data_type() { + DataType::Decimal128(_, _) => { + let array = as_decimal128_array(array)?; + if array.is_null(index) { + Ok(ScalarValue::Decimal128(None, precision, scale)) + } else { + let value = array.value(index); + Ok(ScalarValue::Decimal128(Some(value), precision, scale)) + } + } + 
DataType::Decimal256(_, _) => { + let array = as_decimal256_array(array)?; + if array.is_null(index) { + Ok(ScalarValue::Decimal256(None, precision, scale)) + } else { + let value = array.value(index); + Ok(ScalarValue::Decimal256(Some(value), precision, scale)) + } + } + _ => Err(DataFusionError::Internal( + "Unsupported decimal type".to_string(), + )), } } @@ -3067,6 +3144,11 @@ impl ScalarValue { array, index, *precision, *scale, )? } + DataType::Decimal256(precision, scale) => { + ScalarValue::get_decimal_value_from_array( + array, index, *precision, *scale, + )? + } DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), @@ -3265,6 +3347,25 @@ impl ScalarValue { } } + fn eq_array_decimal256( + array: &ArrayRef, + index: usize, + value: Option<&i256>, + precision: u8, + scale: i8, + ) -> Result { + let array = as_decimal256_array(array)?; + if array.precision() != precision || array.scale() != scale { + return Ok(false); + } + let is_null = array.is_null(index); + if let Some(v) = value { + Ok(!array.is_null(index) && array.value(index) == *v) + } else { + Ok(is_null) + } + } + /// Compares a single row of array @ index for equality with self, /// in an optimized fashion. 
/// @@ -3294,6 +3395,16 @@ impl ScalarValue { ) .unwrap() } + ScalarValue::Decimal256(v, precision, scale) => { + ScalarValue::eq_array_decimal256( + array, + index, + v.as_ref(), + *precision, + *scale, + ) + .unwrap() + } ScalarValue::Boolean(val) => { eq_array_primitive!(array, index, BooleanArray, val) } @@ -3416,6 +3527,7 @@ impl ScalarValue { | ScalarValue::Float32(_) | ScalarValue::Float64(_) | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Decimal256(_, _, _) | ScalarValue::Int8(_) | ScalarValue::Int16(_) | ScalarValue::Int32(_) @@ -3647,6 +3759,22 @@ impl TryFrom for i128 { } } +// special implementation for i256 because of Decimal256 +impl TryFrom for i256 { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::Decimal256(Some(inner_value), _, _) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } +} + impl_try_from!(UInt8, u8); impl_try_from!(UInt16, u16); impl_try_from!(UInt32, u32); @@ -3684,6 +3812,9 @@ impl TryFrom<&DataType> for ScalarValue { DataType::Decimal128(precision, scale) => { ScalarValue::Decimal128(None, *precision, *scale) } + DataType::Decimal256(precision, scale) => { + ScalarValue::Decimal256(None, *precision, *scale) + } DataType::Utf8 => ScalarValue::Utf8(None), DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), DataType::Binary => ScalarValue::Binary(None), @@ -3753,6 +3884,9 @@ impl fmt::Display for ScalarValue { ScalarValue::Decimal128(v, p, s) => { write!(f, "{v:?},{p:?},{s:?}")?; } + ScalarValue::Decimal256(v, p, s) => { + write!(f, "{v:?},{p:?},{s:?}")?; + } ScalarValue::Boolean(e) => format_option!(f, e)?, ScalarValue::Float32(e) => format_option!(f, e)?, ScalarValue::Float64(e) => format_option!(f, e)?, @@ -3830,6 +3964,7 @@ impl fmt::Debug for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ScalarValue::Decimal128(_, _, _) => write!(f, 
"Decimal128({self})"), + ScalarValue::Decimal256(_, _, _) => write!(f, "Decimal256({self})"), ScalarValue::Boolean(_) => write!(f, "Boolean({self})"), ScalarValue::Float32(_) => write!(f, "Float32({self})"), ScalarValue::Float64(_) => write!(f, "Float64({self})"), diff --git a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt index 4a3d39bdebcf..5c82c7e0091d 100644 --- a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt +++ b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt @@ -180,23 +180,29 @@ drop table foo statement ok create table foo as select - arrow_cast(100, 'Decimal128(5,2)') as col_d128 - -- Can't make a decimal 156: - -- This feature is not implemented: Can't create a scalar from array of type "Decimal256(3, 2)" - --arrow_cast(100, 'Decimal256(5,2)') as col_d256 + arrow_cast(100, 'Decimal128(5,2)') as col_d128, + arrow_cast(100, 'Decimal256(5,2)') as col_d256 ; ## Ensure each column in the table has the expected type -query T +query TT SELECT - arrow_typeof(col_d128) - -- arrow_typeof(col_d256), + arrow_typeof(col_d128), + arrow_typeof(col_d256) FROM foo; ---- -Decimal128(5, 2) +Decimal128(5, 2) Decimal256(5, 2) + +query RR +SELECT + col_d128, + col_d256 + FROM foo; +---- +100 100.00 statement ok drop table foo diff --git a/datafusion/core/tests/sqllogictests/test_files/decimal.slt b/datafusion/core/tests/sqllogictests/test_files/decimal.slt index f41351774172..8fd08f87c849 100644 --- a/datafusion/core/tests/sqllogictests/test_files/decimal.slt +++ b/datafusion/core/tests/sqllogictests/test_files/decimal.slt @@ -612,3 +612,12 @@ insert into foo VALUES (1, 5); query error DataFusion error: Arrow error: Compute error: Overflow happened on: 100000000000000000000 \* 100000000000000000000000000000000000000 select a / b from foo; + +statement ok +create table t as values (arrow_cast(123, 'Decimal256(5,2)')); + +query error DataFusion error: Internal 
error: Operator \+ is not implemented for types Decimal256\(None,15,2\) and Decimal256\(Some\(12300\),15,2\)\. This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker +select AVG(column1) from t; + +statement ok +drop table t; diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 1fccdcbd2ca1..dec2eb7f1238 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -17,6 +17,7 @@ use arrow::datatypes::{ DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; use datafusion_common::{DataFusionError, Result}; use std::ops::Deref; @@ -360,6 +361,12 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal128(new_precision, *scale)) } + DataType::Decimal256(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal256(new_precision, *scale)) + } DataType::Dictionary(_, dict_value_type) => { sum_return_type(dict_value_type.as_ref()) } @@ -423,6 +430,13 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); Ok(DataType::Decimal128(new_precision, new_scale)) } + DataType::Decimal256(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). 
+ // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal256(new_precision, new_scale)) + } arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), DataType::Dictionary(_, dict_value_type) => { avg_return_type(dict_value_type.as_ref()) @@ -441,6 +455,11 @@ pub fn avg_sum_type(arg_type: &DataType) -> Result { let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal128(new_precision, *scale)) } + DataType::Decimal256(precision, scale) => { + // in Spark the sum type of avg is DECIMAL(min(38,precision+10), s) + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal256(new_precision, *scale)) + } arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), DataType::Dictionary(_, dict_value_type) => { avg_sum_type(dict_value_type.as_ref()) diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index a1d77a2d8849..9c01093edf5f 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -77,12 +77,12 @@ impl Avg { // the internal sum data type of avg just support FLOAT64 and Decimal data type. assert!(matches!( sum_data_type, - DataType::Float64 | DataType::Decimal128(_, _) + DataType::Float64 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) )); // the result of avg just support FLOAT64 and Decimal data type. 
assert!(matches!( rt_data_type, - DataType::Float64 | DataType::Decimal128(_, _) + DataType::Float64 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) )); Self { name: name.into(), diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 45e2be7fb4c6..9ac90cef4bab 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -28,6 +28,7 @@ use crate::expressions::format_state_name; use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; use arrow::array::Array; use arrow::array::Decimal128Array; +use arrow::array::Decimal256Array; use arrow::compute; use arrow::compute::kernels::cast; use arrow::datatypes::DataType; @@ -39,8 +40,8 @@ use arrow::{ datatypes::Field, }; use arrow_array::types::{ - Decimal128Type, Float32Type, Float64Type, Int32Type, Int64Type, UInt32Type, - UInt64Type, + Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type, Int64Type, + UInt32Type, UInt64Type, }; use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; @@ -169,6 +170,10 @@ impl AggregateExpr for Sum { instantiate_primitive_accumulator!(self, Decimal128Type, |x, y| x .add_assign(y)) } + DataType::Decimal256(_, _) => { + instantiate_primitive_accumulator!(self, Decimal256Type, |x, y| *x = + *x + y) + } _ => Err(DataFusionError::NotImplemented(format!( "GroupsAccumulator not supported for {}: {}", self.name, self.data_type @@ -250,6 +255,16 @@ fn sum_decimal_batch(values: &ArrayRef, precision: u8, scale: i8) -> Result Result { + let array = downcast_value!(values, Decimal256Array); + let result = compute::sum(array); + Ok(ScalarValue::Decimal256(result, precision, scale)) +} + // sums the array and returns a ScalarValue of its corresponding type. 
pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) -> Result { // TODO refine the cast kernel in arrow-rs @@ -263,6 +278,9 @@ pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) -> Result { sum_decimal_batch(values, *precision, *scale)? } + DataType::Decimal256(precision, scale) => { + sum_decimal256_batch(values, *precision, *scale)? + } DataType::Float64 => typed_sum_delta_batch!(values, Float64Array, Float64), DataType::Float32 => typed_sum_delta_batch!(values, Float32Array, Float32), DataType::Int64 => typed_sum_delta_batch!(values, Int64Array, Int64), diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 8192a403d33e..f7247effdd2a 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -908,6 +908,8 @@ message ScalarValue{ //WAS: ScalarType null_list_value = 18; Decimal128 decimal128_value = 20; + Decimal256 decimal256_value = 39; + int64 date_64_value = 21; int32 interval_yearmonth_value = 24; int64 interval_daytime_value = 25; @@ -934,6 +936,12 @@ message Decimal128{ int64 s = 3; } +message Decimal256{ + bytes value = 1; + int64 p = 2; + int64 s = 3; +} + // Serialized data type message ArrowType{ oneof arrow_type_enum { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 05bfbd089dfe..aaf6bb97bb1a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -4983,6 +4983,137 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { deserializer.deserialize_struct("datafusion.Decimal128", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for Decimal256 { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.value.is_empty() { + len += 1; + } + if self.p != 0 { + len += 1; + } + if self.s != 0 { + len += 1; + } + let mut 
struct_ser = serializer.serialize_struct("datafusion.Decimal256", len)?; + if !self.value.is_empty() { + struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; + } + if self.p != 0 { + struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; + } + if self.s != 0 { + struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Decimal256 { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "value", + "p", + "s", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Value, + P, + S, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "value" => Ok(GeneratedField::Value), + "p" => Ok(GeneratedField::P), + "s" => Ok(GeneratedField::S), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal256; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.Decimal256") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value__ = None; + let mut p__ = None; + let mut s__ = None; + while let Some(k) = 
map.next_key()? { + match k { + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = + Some(map.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::P => { + if p__.is_some() { + return Err(serde::de::Error::duplicate_field("p")); + } + p__ = + Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::S => { + if s__.is_some() { + return Err(serde::de::Error::duplicate_field("s")); + } + s__ = + Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal256 { + value: value__.unwrap_or_default(), + p: p__.unwrap_or_default(), + s: s__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.Decimal256", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for DfField { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -19125,6 +19256,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::Decimal128Value(v) => { struct_ser.serialize_field("decimal128Value", v)?; } + scalar_value::Value::Decimal256Value(v) => { + struct_ser.serialize_field("decimal256Value", v)?; + } scalar_value::Value::Date64Value(v) => { struct_ser.serialize_field("date64Value", ToString::to_string(&v).as_str())?; } @@ -19218,6 +19352,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "listValue", "decimal128_value", "decimal128Value", + "decimal256_value", + "decimal256Value", "date_64_value", "date64Value", "interval_yearmonth_value", @@ -19270,6 +19406,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { Time32Value, ListValue, Decimal128Value, + Decimal256Value, Date64Value, IntervalYearmonthValue, IntervalDaytimeValue, @@ -19324,6 +19461,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "time32Value" | "time32_value" => Ok(GeneratedField::Time32Value), "listValue" | "list_value" => Ok(GeneratedField::ListValue), "decimal128Value" | 
"decimal128_value" => Ok(GeneratedField::Decimal128Value), + "decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value), "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value), "intervalYearmonthValue" | "interval_yearmonth_value" => Ok(GeneratedField::IntervalYearmonthValue), "intervalDaytimeValue" | "interval_daytime_value" => Ok(GeneratedField::IntervalDaytimeValue), @@ -19471,6 +19609,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("decimal128Value")); } value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal128Value) +; + } + GeneratedField::Decimal256Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("decimal256Value")); + } + value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal256Value) ; } GeneratedField::Date64Value => { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index f50754494d1d..e1ad6acec832 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1097,7 +1097,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" )] pub value: ::core::option::Option, } @@ -1146,6 +1146,8 @@ pub mod scalar_value { ListValue(super::ScalarListValue), #[prost(message, tag = "20")] Decimal128Value(super::Decimal128), + #[prost(message, tag = "39")] + Decimal256Value(super::Decimal256), #[prost(int64, tag = "21")] Date64Value(i64), #[prost(int32, tag = "24")] @@ -1188,6 +1190,16 @@ pub struct Decimal128 { #[prost(int64, tag = "3")] pub s: i64, } 
+#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Decimal256 { + #[prost(bytes = "vec", tag = "1")] + pub value: ::prost::alloc::vec::Vec, + #[prost(int64, tag = "2")] + pub p: i64, + #[prost(int64, tag = "3")] + pub s: i64, +} /// Serialized data type #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 674588692d98..71a1bf87db6e 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -26,7 +26,7 @@ use crate::protobuf::{ OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; use arrow::datatypes::{ - DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, + i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode, }; use datafusion::execution::registry::FunctionRegistry; @@ -648,6 +648,14 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { val.s as i8, ) } + Value::Decimal256Value(val) => { + let array = vec_to_array(val.value.clone()); + Self::Decimal256( + Some(i256::from_be_bytes(array)), + val.p as u8, + val.s as i8, + ) + } Value::Date64Value(v) => Self::Date64(Some(*v)), Value::Time32Value(v) => { let time_value = diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 072bc84d5452..f1a961576128 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1148,6 +1148,24 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { )), }), }, + ScalarValue::Decimal256(val, p, s) => match *val { + Some(v) => { + let array = v.to_be_bytes(); + let vec_val: Vec = array.to_vec(); + Ok(protobuf::ScalarValue { + value: Some(Value::Decimal256Value(protobuf::Decimal256 { + value: vec_val, + p: *p as i64, + s: *s as 
i64, + })), + }) + } + None => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::NullValue( + (&data_type).try_into()?, + )), + }), + }, ScalarValue::Date64(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date64Value(*s)) }