Skip to content

Commit

Permalink
Tests for eq_array, and bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Aug 9, 2021
1 parent 8dec8cc commit 89e61ef
Showing 1 changed file with 172 additions and 13 deletions.
185 changes: 172 additions & 13 deletions datafusion/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,24 +278,28 @@ impl std::hash::Hash for ScalarValue {
}

// return the index into the dictionary values for array@index as well
// as the dictioanry values
// as a reference to the dictionary values array. Returns None for the
// index if the array is NULL at index
#[inline]
fn get_dict_value<K: ArrowDictionaryKeyType>(
array: &ArrayRef,
index: usize,
// NOTE(review): the two return-type lines below appear to be the before/after
// sides of this diff; the `Option<usize>` form is the one that matches the
// "Returns None for the index" behavior implemented in the body -- confirm.
) -> Result<(&ArrayRef, usize)> {
) -> Result<(&ArrayRef, Option<usize>)> {
// the unwrap panics if `array` is not a `DictionaryArray<K>`, so callers
// must pick `K` to match the array's key type
let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();

// look up the index in the values dictionary
let keys_col = dict_array.keys();
// the key slot at `index` is null: hand back the values array with no
// index rather than reading a key value from a null slot
if !keys_col.is_valid(index) {
return Ok((dict_array.values(), None));
}
// a key that cannot be represented as usize is reported as an internal
// error instead of panicking
let values_index = keys_col.value(index).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
"Can not convert index to usize in dictionary of type creating group by value {:?}",
keys_col.data_type()
))
})?;

// NOTE(review): as with the signature, the two `Ok(...)` lines below appear
// to be the before/after sides of the diff -- confirm.
Ok((dict_array.values(), values_index))
Ok((dict_array.values(), Some(values_index)))
}

macro_rules! typed_cast {
Expand Down Expand Up @@ -1023,17 +1027,16 @@ impl ScalarValue {
Self::try_from_array(dict_array.values(), values_index)
}

/// Compares array @ index for equality with self.
/// Compares array @ index for equality with self, in an optimized fashion
///
/// This method implements an optimized version of:
///
/// TODO: optimize: avoid constructing an intermediate ScalarValue
/// ```text
/// let arr_scalar = Self::try_from_array(array, index).unwrap();
/// arr_scalar.eq(self)
/// ```
#[inline]
pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool {
// This code is an optimized version of this:
//
// let arr_scalar = Self::try_from_array(array, index).unwrap();
// arr_scalar.eq(self)

// handle dictionary arrays specially
if let DataType::Dictionary(key_type, _) = array.data_type() {
return self.eq_array_dictionary(array, index, key_type);
}
Expand Down Expand Up @@ -1101,7 +1104,7 @@ impl ScalarValue {
}

/// Compares a dictionary array with indexes of type `key_type`
/// with the array @ index for equality with selfñ
/// with the array @ index for equality with self
fn eq_array_dictionary(
&self,
array: &ArrayRef,
Expand All @@ -1120,7 +1123,10 @@ impl ScalarValue {
_ => unreachable!("Invalid dictionary keys type: {:?}", key_type),
};

self.eq_array(values, values_index)
match values_index {
Some(values_index) => self.eq_array(values, values_index),
None => self.is_null(),
}
}
}

Expand Down Expand Up @@ -1786,6 +1792,159 @@ mod tests {
assert_eq!(std::mem::size_of::<ScalarValue>(), 32);
}

#[test]
fn scalar_eq_array() {
// Validate that eq_array has the same semantics as ScalarValue::eq:
// each scalar must compare equal to the array element it was built from
// and unequal to every other element of that array.

// Convert a vec of Option<number> to a vec of Option<$TYPE> using an
// `as` cast on each non-null element; None stays None.
macro_rules! make_typed_vec {
($INPUT:expr, $TYPE:ident) => {{
$INPUT
.iter()
.map(|v| v.map(|v| v as $TYPE))
.collect::<Vec<_>>()
}};
}

// Every input below has three distinct entries with a None in the middle,
// so the null case is exercised for each array type.
let bool_vals = vec![Some(true), None, Some(false)];
let f32_vals = vec![Some(-1.0), None, Some(1.0)];
let f64_vals = make_typed_vec!(f32_vals, f64);

let i8_vals = vec![Some(-1), None, Some(1)];
let i16_vals = make_typed_vec!(i8_vals, i16);
let i32_vals = make_typed_vec!(i8_vals, i32);
let i64_vals = make_typed_vec!(i8_vals, i64);

let u8_vals = vec![Some(0), None, Some(1)];
let u16_vals = make_typed_vec!(u8_vals, u16);
let u32_vals = make_typed_vec!(u8_vals, u32);
let u64_vals = make_typed_vec!(u8_vals, u64);

let str_vals = vec![Some("foo"), None, Some("bar")];

/// Test each value in `scalars` against the corresponding element
/// of `array`. Assumes each element is unique (i.e. not equal
/// to the element at any other index)
struct TestCase {
array: ArrayRef,
scalars: Vec<ScalarValue>,
}

/// Create a test case by collecting the input into the specified array
/// type and wrapping each input element in the given ScalarValue variant
macro_rules! make_test_case {
($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{
TestCase {
array: Arc::new($INPUT.iter().collect::<$ARRAY_TY>()),
scalars: $INPUT.iter().map(|v| ScalarValue::$SCALAR_TY(*v)).collect(),
}
}};
}

// Same as make_test_case, but for string inputs: the scalar variant
// holds an owned String built from each &str.
macro_rules! make_str_test_case {
($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{
TestCase {
array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()),
scalars: $INPUT
.iter()
.map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string())))
.collect(),
}
}};
}

// Same as make_str_test_case, but the scalar variant holds the bytes of
// each string as a Vec<u8>.
macro_rules! make_binary_test_case {
($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{
TestCase {
array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()),
scalars: $INPUT
.iter()
.map(|v| {
ScalarValue::$SCALAR_TY(v.map(|v| v.as_bytes().to_vec()))
})
.collect(),
}
}};
}

/// create a test case for DictionaryArray<$INDEX_TY>
macro_rules! make_str_dict_test_case {
($INPUT:expr, $INDEX_TY:ident, $SCALAR_TY:ident) => {{
TestCase {
array: Arc::new(
$INPUT
.iter()
.cloned()
.collect::<DictionaryArray<$INDEX_TY>>(),
),
scalars: $INPUT
.iter()
.map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string())))
.collect(),
}
}};
}

// One case per (array type, ScalarValue variant) pair; the last eight
// cover dictionary arrays with each key type, compared against Utf8 scalars.
let cases = vec![
make_test_case!(bool_vals, BooleanArray, Boolean),
make_test_case!(f32_vals, Float32Array, Float32),
make_test_case!(f64_vals, Float64Array, Float64),
make_test_case!(i8_vals, Int8Array, Int8),
make_test_case!(i16_vals, Int16Array, Int16),
make_test_case!(i32_vals, Int32Array, Int32),
make_test_case!(i64_vals, Int64Array, Int64),
make_test_case!(u8_vals, UInt8Array, UInt8),
make_test_case!(u16_vals, UInt16Array, UInt16),
make_test_case!(u32_vals, UInt32Array, UInt32),
make_test_case!(u64_vals, UInt64Array, UInt64),
make_str_test_case!(str_vals, StringArray, Utf8),
make_str_test_case!(str_vals, LargeStringArray, LargeUtf8),
make_binary_test_case!(str_vals, BinaryArray, Binary),
make_binary_test_case!(str_vals, LargeBinaryArray, LargeBinary),
make_test_case!(i32_vals, Date32Array, Date32),
make_test_case!(i64_vals, Date64Array, Date64),
make_test_case!(i64_vals, TimestampSecondArray, TimestampSecond),
make_test_case!(i64_vals, TimestampMillisecondArray, TimestampMillisecond),
make_test_case!(i64_vals, TimestampMicrosecondArray, TimestampMicrosecond),
make_test_case!(i64_vals, TimestampNanosecondArray, TimestampNanosecond),
make_test_case!(i32_vals, IntervalYearMonthArray, IntervalYearMonth),
make_test_case!(i64_vals, IntervalDayTimeArray, IntervalDayTime),
make_str_dict_test_case!(str_vals, Int8Type, Utf8),
make_str_dict_test_case!(str_vals, Int16Type, Utf8),
make_str_dict_test_case!(str_vals, Int32Type, Utf8),
make_str_dict_test_case!(str_vals, Int64Type, Utf8),
make_str_dict_test_case!(str_vals, UInt8Type, Utf8),
make_str_dict_test_case!(str_vals, UInt16Type, Utf8),
make_str_dict_test_case!(str_vals, UInt32Type, Utf8),
make_str_dict_test_case!(str_vals, UInt64Type, Utf8),
];

for case in cases {
let TestCase { array, scalars } = case;
// scalars[i] was built from the same input element as array[i]
assert_eq!(array.len(), scalars.len());

for (index, scalar) in scalars.into_iter().enumerate() {
// the scalar must match the element it was derived from
// (including the null scalar matching the null slot)
assert!(
scalar.eq_array(&array, index),
"Expected {:?} to be equal to {:?} at index {}",
scalar,
array,
index
);

// test that all other elements are *not* equal
for other_index in 0..array.len() {
if index != other_index {
assert!(
!scalar.eq_array(&array, other_index),
"Expected {:?} to be NOT equal to {:?} at index {}",
scalar,
array,
other_index
);
}
}
}
}
}

#[test]
fn scalar_partial_ordering() {
use ScalarValue::*;
Expand Down

0 comments on commit 89e61ef

Please sign in to comment.