Skip to content
This repository has been archived by the owner on Jan 13, 2023. It is now read-only.

Add support for NULL literals (integer and decimal types) #40

Merged
merged 2 commits into from
Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use datafusion::{
};

use datafusion::sql::TableReference;
use substrait::protobuf::Type;
use substrait::protobuf::{
aggregate_function::AggregationInvocation,
expression::{
Expand All @@ -20,7 +21,7 @@ use substrait::protobuf::{
},
extensions::simple_extension_declaration::MappingType,
function_argument::ArgType,
join_rel,
join_rel, r#type,
read_rel::ReadType,
rel::RelType,
sort_field::{SortDirection, SortKind::*},
Expand Down Expand Up @@ -596,6 +597,9 @@ pub async fn from_substrait_rex(
Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(ScalarValue::Binary(Some(
b.clone(),
))))),
Some(LiteralType::Null(ntype)) => {
Ok(Arc::new(Expr::Literal(from_substrait_null(ntype)?)))
}
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Unsupported literal_type: {:?}",
Expand All @@ -608,3 +612,27 @@ pub async fn from_substrait_rex(
)),
}
}

fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
if let Some(kind) = &null_type.kind {
match kind {
r#type::Kind::I8(_) => Ok(ScalarValue::Int8(None)),
r#type::Kind::I16(_) => Ok(ScalarValue::Int16(None)),
r#type::Kind::I32(_) => Ok(ScalarValue::Int32(None)),
r#type::Kind::I64(_) => Ok(ScalarValue::Int64(None)),
r#type::Kind::Decimal(d) => Ok(ScalarValue::Decimal128(
None,
d.precision as u8,
d.scale as u8,
)),
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported null kind: {:?}",
kind
))),
}
} else {
return Err(DataFusionError::NotImplemented(
"Null type without kind is not supported".to_string(),
));
}
}
65 changes: 54 additions & 11 deletions src/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use substrait::protobuf::{
simple_extension_declaration::{ExtensionFunction, MappingType},
},
function_argument::ArgType,
join_rel, plan_rel,
join_rel, plan_rel, r#type,
read_rel::{NamedTable, ReadType},
rel::RelType,
sort_field::{SortDirection, SortKind},
Expand Down Expand Up @@ -359,9 +359,13 @@ pub fn to_substrait_agg_measure(
},
})
}
Expr::Alias(expr, _name) => {
to_substrait_agg_measure(expr, schema, extension_info)
}
_ => Err(DataFusionError::Internal(format!(
"Expression must be compatible with aggregation. Unsupported expression: {:?}",
expr
"Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}",
expr,
expr.variant_name()
))),
}
}
Expand Down Expand Up @@ -568,8 +572,8 @@ pub fn to_substrait_rex(
ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)),
ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
ScalarValue::Decimal128(v, p, s) => Some(LiteralType::Decimal(Decimal {
value: v.unwrap().to_le_bytes().to_vec(),
ScalarValue::Decimal128(Some(v), p, s) => Some(LiteralType::Decimal(Decimal {
value: v.to_le_bytes().to_vec(),
precision: *p as i32,
scale: *s as i32,
})),
Expand All @@ -578,12 +582,7 @@ pub fn to_substrait_rex(
ScalarValue::Binary(Some(b)) => Some(LiteralType::Binary(b.clone())),
ScalarValue::LargeBinary(Some(b)) => Some(LiteralType::Binary(b.clone())),
ScalarValue::Date32(Some(d)) => Some(LiteralType::Date(*d)),
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Unsupported literal: {:?}",
value
)))
}
_ => Some(try_to_substrait_null(value)?),
};
Ok(Expression {
rex_type: Some(RexType::Literal(Literal {
Expand All @@ -601,6 +600,50 @@ pub fn to_substrait_rex(
}
}

fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
let default_type_ref = 0;
let default_nullability = r#type::Nullability::Nullable as i32;
match v {
ScalarValue::Int8(None) => Ok(LiteralType::Null(substrait::protobuf::Type {
kind: Some(r#type::Kind::I8(r#type::I8 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
ScalarValue::Int16(None) => Ok(LiteralType::Null(substrait::protobuf::Type {
kind: Some(r#type::Kind::I16(r#type::I16 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
ScalarValue::Int32(None) => Ok(LiteralType::Null(substrait::protobuf::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
ScalarValue::Int64(None) => Ok(LiteralType::Null(substrait::protobuf::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
ScalarValue::Decimal128(None, p, s) => Ok(LiteralType::Null(substrait::protobuf::Type {
kind: Some(r#type::Kind::Decimal(r#type::Decimal {
scale: *s as i32,
precision: *p as i32,
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
// TODO: Extend support for remaining data types
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported literal: {:?}",
v
))),
}
}

fn substrait_sort_field(
expr: &Expr,
schema: &DFSchemaRef,
Expand Down
16 changes: 16 additions & 0 deletions tests/roundtrip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ mod tests {
roundtrip("SELECT * FROM data WHERE b > 2.5").await
}

#[tokio::test]
async fn null_decimal_literal() -> Result<()> {
roundtrip("SELECT * FROM data WHERE b = NULL").await
}

#[tokio::test]
async fn simple_distinct() -> Result<()> {
test_alias(
Expand Down Expand Up @@ -146,6 +151,17 @@ mod tests {
.await
}

#[tokio::test]
async fn aggregate_case() -> Result<()> {
assert_expected_plan(
"SELECT SUM(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data",
"Projection: SUM(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)\
\n Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)]]\
\n TableScan: data projection=[a]",
)
.await
}

#[tokio::test]
async fn roundtrip_inner_join() -> Result<()> {
roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await
Expand Down