Skip to content

Commit

Permalink
Extended datatypes & signatures support for NULLIF function (#4737)
Browse files Browse the repository at this point in the history
* extended nullif datatypes & signatures support

* sqllogictests & type inheritance
  • Loading branch information
korowa authored Dec 26, 2022
1 parent 8ec511e commit a8f1f8a
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 44 deletions.
98 changes: 98 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/nullif.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Fixture shared by the array queries below: six rows mixing NULL and
# non-NULL values across integer, boolean and text columns.
statement ok
CREATE TABLE test(
int_field INT,
bool_field BOOLEAN,
text_field TEXT,
more_ints INT
) as VALUES
(1, true, 'abc', 2),
(2, false, 'def', 2),
(3, NULL, 'ghij', 3),
(NULL, NULL, NULL, 4),
(4, false, 'zxc', 5),
(NULL, true, NULL, 6)
;

# Arrays tests
# Integer column vs. integer literal: the row holding 2 becomes NULL,
# rows that are already NULL stay NULL, everything else passes through.
query T
SELECT NULLIF(int_field, 2) FROM test;
----
1
NULL
3
NULL
4
NULL

# Boolean column vs. boolean literal: every `false` row becomes NULL,
# so only the two `true` rows keep a value.
query T
SELECT NULLIF(bool_field, false) FROM test;
----
true
NULL
NULL
NULL
NULL
true

# Text column vs. text literal: only the 'zxc' row is nulled in addition
# to the rows whose text_field is already NULL.
query T
SELECT NULLIF(text_field, 'zxc') FROM test;
----
abc
def
ghij
NULL
NULL
NULL

# Column vs. column: rows where int_field equals more_ints (2 = 2, 3 = 3)
# become NULL; rows with a NULL int_field remain NULL.
query T
SELECT NULLIF(int_field, more_ints) FROM test;
----
1
NULL
NULL
NULL
4
NULL

# Literal as the first argument, column as the second: the literal 3 is
# returned for every row except the one where int_field is 3. Rows where
# int_field is NULL still yield 3.
query T
SELECT NULLIF(3, int_field) FROM test;
----
3
3
NULL
3
3
3

# Scalar values tests
# Equal literals produce NULL.
query T
SELECT NULLIF(1, 1);
----
NULL

# Different literals return the first argument.
query T
SELECT NULLIF(1, 3);
----
1

# Both arguments NULL: the result is NULL.
query T
SELECT NULLIF(NULL, NULL);
----
NULL
2 changes: 2 additions & 0 deletions datafusion/expr/src/nullif.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ pub static SUPPORTED_NULLIF_TYPES: &[DataType] = &[
DataType::Int64,
DataType::Float32,
DataType::Float64,
DataType::Utf8,
DataType::LargeUtf8,
];
151 changes: 107 additions & 44 deletions datafusion/physical-expr/src/expressions/nullif.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,52 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::array::Array;
use arrow::array::*;
use arrow::compute::eq_dyn;
use arrow::compute::nullif::nullif;
use arrow::datatypes::DataType;
use datafusion_common::{cast::as_boolean_array, DataFusionError, Result};
use datafusion_common::{cast::as_boolean_array, DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;

use super::binary::array_eq_scalar;

/// Applies the kernel `$OP` to a concretely typed array and a boolean array.
///
/// `$LEFT` is downcast to the array type `$DT` and `$RIGHT` is viewed as a
/// boolean array; both conversions panic on a type mismatch. The kernel's
/// own error is propagated with `?`, and its output is wrapped as an
/// `ArrayRef` inside `Ok`.
macro_rules! compute_bool_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let typed_values = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        let mask =
            as_boolean_array($RIGHT).expect("compute_op failed to downcast array");
        let computed = $OP(&typed_values, &mask)?;
        Ok(Arc::new(computed) as ArrayRef)
    }};
}

/// Dispatches `$OP` over a primitive array (`$LEFT`) and a boolean array
/// (`$RIGHT`) according to the runtime data type of `$LEFT`.
///
/// The 8/16/32/64-bit signed and unsigned integers and the 32/64-bit floats
/// are forwarded to `compute_bool_array_op!` with the matching array type.
/// Any other data type evaluates to a `DataFusionError::Internal`.
macro_rules! primitive_bool_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        match $LEFT.data_type() {
            // Floating point
            DataType::Float32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float64Array),
            // Unsigned integers
            DataType::UInt8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt64Array),
            // Signed integers
            DataType::Int8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int64Array),
            unsupported => Err(DataFusionError::Internal(format!(
                "Unsupported data type {:?} for NULLIF/primitive/boolean operator",
                unsupported
            ))),
        }
    }};
}

/// Implements NULLIF(expr1, expr2)
/// Args: 0 - left expr, either an array or a scalar value
/// 1 - right expr; where the left value equals it the result is NULL, otherwise the left value is passed through.
Expand All @@ -79,7 +41,7 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
let cond_array = array_eq_scalar(lhs, rhs)?;

let array = primitive_bool_array_op!(lhs, &cond_array, nullif)?;
let array = nullif(lhs, as_boolean_array(&cond_array)?)?;

Ok(ColumnarValue::Array(array))
}
Expand All @@ -88,17 +50,34 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let cond_array = eq_dyn(lhs, rhs)?;

// Now, invoke nullif on the result
let array = primitive_bool_array_op!(lhs, &cond_array, nullif)?;
let array = nullif(lhs, as_boolean_array(&cond_array)?)?;
Ok(ColumnarValue::Array(array))
}
(ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => {
// Same as the Array-Array case, except that the scalar is first expanded into an array
let lhs = lhs.to_array_of_size(rhs.len());
let cond_array = eq_dyn(&lhs, rhs)?;

let array = nullif(&lhs, as_boolean_array(&cond_array)?)?;
Ok(ColumnarValue::Array(array))
}
_ => Err(DataFusionError::NotImplemented(
"nullif does not support a literal as first argument".to_string(),
)),
(ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => {
let val: ScalarValue = match lhs.eq(rhs) {
true => lhs.get_datatype().try_into()?,
false => lhs.clone(),
};

Ok(ColumnarValue::Scalar(val))
}
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use arrow::array::*;

use super::*;
use datafusion_common::{Result, ScalarValue};

Expand Down Expand Up @@ -162,4 +141,88 @@ mod tests {
assert_eq!(expected.as_ref(), result.as_ref());
Ok(())
}

/// NULLIF over a boolean array against the scalar `false`: the `false`
/// element becomes null, `true` is kept and the null element stays null.
#[test]
fn nullif_boolean() -> Result<()> {
    let input = Arc::new(BooleanArray::from(vec![Some(true), Some(false), None]));
    let args = [
        ColumnarValue::Array(input),
        ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))),
    ];

    let actual = nullif_func(&args)?.into_array(0);

    let expected =
        Arc::new(BooleanArray::from(vec![Some(true), None, None])) as ArrayRef;
    assert_eq!(expected.as_ref(), actual.as_ref());

    Ok(())
}

/// NULLIF over a string array against the scalar "bar": the matching
/// element becomes null, the existing null stays null, the rest pass through.
#[test]
fn nullif_string() -> Result<()> {
    let input = Arc::new(StringArray::from(vec![
        Some("foo"),
        Some("bar"),
        None,
        Some("baz"),
    ]));
    let needle = ScalarValue::Utf8(Some("bar".to_string()));
    let args = [ColumnarValue::Array(input), ColumnarValue::Scalar(needle)];

    let actual = nullif_func(&args)?.into_array(0);

    let expected = Arc::new(StringArray::from(vec![
        Some("foo"),
        None,
        None,
        Some("baz"),
    ])) as ArrayRef;
    assert_eq!(expected.as_ref(), actual.as_ref());

    Ok(())
}

/// NULLIF with the scalar as the FIRST argument and an array as the second:
/// the scalar 2 is produced for every position except the one where the
/// array also holds 2, which becomes null.
#[test]
fn nullif_literal_first() -> Result<()> {
    let column = Arc::new(Int32Array::from(vec![
        Some(1),
        Some(2),
        None,
        None,
        Some(3),
        Some(4),
    ]));
    let args = [
        ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))),
        ColumnarValue::Array(column),
    ];

    let actual = nullif_func(&args)?.into_array(0);

    let expected = Arc::new(Int32Array::from(vec![
        Some(2),
        None,
        Some(2),
        Some(2),
        Some(2),
        Some(2),
    ])) as ArrayRef;
    assert_eq!(expected.as_ref(), actual.as_ref());

    Ok(())
}

/// NULLIF with two scalar arguments: equal scalars produce a null of the
/// same type, different scalars return the first argument unchanged.
#[test]
fn nullif_scalar() -> Result<()> {
    // Equal operands -> null.
    let equal_args = [
        ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))),
    ];
    let actual_eq = nullif_func(&equal_args)?.into_array(1);
    let expected_eq = Arc::new(Int32Array::from(vec![None])) as ArrayRef;
    assert_eq!(expected_eq.as_ref(), actual_eq.as_ref());

    // Different operands -> the left value.
    let differing_args = [
        ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))),
    ];
    let actual_neq = nullif_func(&differing_args)?.into_array(1);
    let expected_neq = Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef;
    assert_eq!(expected_neq.as_ref(), actual_neq.as_ref());

    Ok(())
}
}

0 comments on commit a8f1f8a

Please sign in to comment.