Skip to content
This repository has been archived by the owner on Jan 13, 2023. It is now read-only.

Commit

Permalink
Add support for NULL literals (integer and decimal types) (#40)
Browse files Browse the repository at this point in the history
* Add NULL literal support for integer types

* Add NULL literal support for decimal type

Cargo fmt
  • Loading branch information
nseekhao authored Dec 15, 2022
1 parent adaccce commit 0964c39
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 12 deletions.
30 changes: 29 additions & 1 deletion src/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use datafusion::{
};

use datafusion::sql::TableReference;
use substrait::protobuf::Type;
use substrait::protobuf::{
aggregate_function::AggregationInvocation,
expression::{
Expand All @@ -20,7 +21,7 @@ use substrait::protobuf::{
},
extensions::simple_extension_declaration::MappingType,
function_argument::ArgType,
join_rel,
join_rel, r#type,
read_rel::ReadType,
rel::RelType,
sort_field::{SortDirection, SortKind::*},
Expand Down Expand Up @@ -596,6 +597,9 @@ pub async fn from_substrait_rex(
Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(ScalarValue::Binary(Some(
b.clone(),
))))),
Some(LiteralType::Null(ntype)) => {
Ok(Arc::new(Expr::Literal(from_substrait_null(ntype)?)))
}
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Unsupported literal_type: {:?}",
Expand All @@ -608,3 +612,27 @@ pub async fn from_substrait_rex(
)),
}
}

/// Converts the type carried by a Substrait `NULL` literal into the matching
/// typed-null DataFusion `ScalarValue`.
///
/// Supported kinds are `I8`, `I16`, `I32`, `I64` and `Decimal`.
///
/// # Errors
///
/// * `DataFusionError::NotImplemented` if `null_type` has no kind, or its kind
///   is not one of the supported ones.
/// * `DataFusionError::Plan` if a decimal's precision or scale cannot be
///   represented as `u8` (for example a negative or very large value).
fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
    let kind = match &null_type.kind {
        Some(kind) => kind,
        None => {
            return Err(DataFusionError::NotImplemented(
                "Null type without kind is not supported".to_string(),
            ))
        }
    };

    // Checked narrowing: a plain `as u8` cast would silently wrap values that
    // are out of range and yield a decimal type different from the one the
    // plan asked for.
    let to_u8 = |field: &str, value: i32| -> Result<u8> {
        <u8 as std::convert::TryFrom<i32>>::try_from(value).map_err(|_| {
            DataFusionError::Plan(format!(
                "Decimal {} {} of null literal is out of range",
                field, value
            ))
        })
    };

    match kind {
        r#type::Kind::I8(_) => Ok(ScalarValue::Int8(None)),
        r#type::Kind::I16(_) => Ok(ScalarValue::Int16(None)),
        r#type::Kind::I32(_) => Ok(ScalarValue::Int32(None)),
        r#type::Kind::I64(_) => Ok(ScalarValue::Int64(None)),
        r#type::Kind::Decimal(d) => Ok(ScalarValue::Decimal128(
            None,
            to_u8("precision", d.precision)?,
            to_u8("scale", d.scale)?,
        )),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported null kind: {:?}",
            kind
        ))),
    }
}
65 changes: 54 additions & 11 deletions src/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use substrait::protobuf::{
simple_extension_declaration::{ExtensionFunction, MappingType},
},
function_argument::ArgType,
join_rel, plan_rel,
join_rel, plan_rel, r#type,
read_rel::{NamedTable, ReadType},
rel::RelType,
sort_field::{SortDirection, SortKind},
Expand Down Expand Up @@ -359,9 +359,13 @@ pub fn to_substrait_agg_measure(
},
})
}
Expr::Alias(expr, _name) => {
to_substrait_agg_measure(expr, schema, extension_info)
}
_ => Err(DataFusionError::Internal(format!(
"Expression must be compatible with aggregation. Unsupported expression: {:?}",
expr
"Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}",
expr,
expr.variant_name()
))),
}
}
Expand Down Expand Up @@ -568,8 +572,8 @@ pub fn to_substrait_rex(
ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)),
ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
ScalarValue::Decimal128(v, p, s) => Some(LiteralType::Decimal(Decimal {
value: v.unwrap().to_le_bytes().to_vec(),
ScalarValue::Decimal128(Some(v), p, s) => Some(LiteralType::Decimal(Decimal {
value: v.to_le_bytes().to_vec(),
precision: *p as i32,
scale: *s as i32,
})),
Expand All @@ -578,12 +582,7 @@ pub fn to_substrait_rex(
ScalarValue::Binary(Some(b)) => Some(LiteralType::Binary(b.clone())),
ScalarValue::LargeBinary(Some(b)) => Some(LiteralType::Binary(b.clone())),
ScalarValue::Date32(Some(d)) => Some(LiteralType::Date(*d)),
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Unsupported literal: {:?}",
value
)))
}
_ => Some(try_to_substrait_null(value)?),
};
Ok(Expression {
rex_type: Some(RexType::Literal(Literal {
Expand All @@ -601,6 +600,50 @@ pub fn to_substrait_rex(
}
}

fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
let default_type_ref = 0;
let default_nullability = r#type::Nullability::Nullable as i32;
match v {
ScalarValue::Int8(None) => Ok(LiteralType::Null(substrait::protobuf::Type {
kind: Some(r#type::Kind::I8(r#type::I8 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
ScalarValue::Int16(None) => Ok(LiteralType::Null(substrait::protobuf::Type {
kind: Some(r#type::Kind::I16(r#type::I16 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
ScalarValue::Int32(None) => Ok(LiteralType::Null(substrait::protobuf::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
ScalarValue::Int64(None) => Ok(LiteralType::Null(substrait::protobuf::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
ScalarValue::Decimal128(None, p, s) => Ok(LiteralType::Null(substrait::protobuf::Type {
kind: Some(r#type::Kind::Decimal(r#type::Decimal {
scale: *s as i32,
precision: *p as i32,
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
// TODO: Extend support for remaining data types
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported literal: {:?}",
v
))),
}
}

fn substrait_sort_field(
expr: &Expr,
schema: &DFSchemaRef,
Expand Down
16 changes: 16 additions & 0 deletions tests/roundtrip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ mod tests {
roundtrip("SELECT * FROM data WHERE b > 2.5").await
}

/// Round-trips a query whose filter compares column `b` with a SQL `NULL`
/// literal.
// NOTE(review): the test name suggests `b` is a decimal column; confirm against
// the test fixture.
#[tokio::test]
async fn null_decimal_literal() -> Result<()> {
    let sql = "SELECT * FROM data WHERE b = NULL";
    roundtrip(sql).await
}

#[tokio::test]
async fn simple_distinct() -> Result<()> {
test_alias(
Expand Down Expand Up @@ -146,6 +151,17 @@ mod tests {
.await
}

/// Checks the plan produced for a `SUM` over a `CASE` expression whose `ELSE`
/// branch is a `NULL` literal.
#[tokio::test]
async fn aggregate_case() -> Result<()> {
    let sql = "SELECT SUM(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data";
    let expected_plan =
        "Projection: SUM(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)\
        \n Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)]]\
        \n TableScan: data projection=[a]";
    assert_expected_plan(sql, expected_plan).await
}

#[tokio::test]
async fn roundtrip_inner_join() -> Result<()> {
roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await
Expand Down

0 comments on commit 0964c39

Please sign in to comment.