From c66b9135a0c0a936586a65961c7174544b2f9309 Mon Sep 17 00:00:00 2001 From: Nuttiiya Seekhao Date: Mon, 12 Dec 2022 15:03:43 -0800 Subject: [PATCH] Add NULL literal support for decimal type Cargo fmt --- src/consumer.rs | 5 +++++ src/producer.rs | 12 ++++++++++-- tests/roundtrip.rs | 5 +++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/consumer.rs b/src/consumer.rs index 5d5c7ca..d786a98 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -620,6 +620,11 @@ fn from_substrait_null(null_type: &Type) -> Result { r#type::Kind::I16(_) => Ok(ScalarValue::Int16(None)), r#type::Kind::I32(_) => Ok(ScalarValue::Int32(None)), r#type::Kind::I64(_) => Ok(ScalarValue::Int64(None)), + r#type::Kind::Decimal(d) => Ok(ScalarValue::Decimal128( + None, + d.precision as u8, + d.scale as u8, + )), _ => Err(DataFusionError::NotImplemented(format!( "Unsupported null kind: {:?}", kind diff --git a/src/producer.rs b/src/producer.rs index 46b58c3..6dc7cf4 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -572,8 +572,8 @@ pub fn to_substrait_rex( ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)), ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)), ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)), - ScalarValue::Decimal128(v, p, s) => Some(LiteralType::Decimal(Decimal { - value: v.unwrap().to_le_bytes().to_vec(), + ScalarValue::Decimal128(Some(v), p, s) => Some(LiteralType::Decimal(Decimal { + value: v.to_le_bytes().to_vec(), precision: *p as i32, scale: *s as i32, })), @@ -628,6 +628,14 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result { nullability: default_nullability, })), })), + ScalarValue::Decimal128(None, p, s) => Ok(LiteralType::Null(substrait::protobuf::Type { + kind: Some(r#type::Kind::Decimal(r#type::Decimal { + scale: *s as i32, + precision: *p as i32, + type_variation_reference: default_type_ref, + nullability: default_nullability, + })), + })), // TODO: Extend support for remaining data types _ => Err(DataFusionError::NotImplemented(format!( "Unsupported literal: {:?}", diff --git a/tests/roundtrip.rs b/tests/roundtrip.rs index fa05a97..11b2112 100644 --- a/tests/roundtrip.rs +++ b/tests/roundtrip.rs @@ -79,6 +79,11 @@ mod tests { roundtrip("SELECT * FROM data WHERE b > 2.5").await } + #[tokio::test] + async fn null_decimal_literal() -> Result<()> { + roundtrip("SELECT * FROM data WHERE b = NULL").await + } + #[tokio::test] async fn simple_distinct() -> Result<()> { test_alias(