From f4882120da3242d620c31f1e474ac6e5899a293e Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Wed, 2 Nov 2022 12:42:53 +1300 Subject: [PATCH] Support dictionary in InList (#3936) --- datafusion/core/tests/sql/predicates.rs | 99 ++++++++++++++++++- datafusion/physical-expr/Cargo.toml | 1 + .../physical-expr/src/expressions/in_list.rs | 30 ++++-- 3 files changed, 119 insertions(+), 11 deletions(-) diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs index 3eea9430004b..21d46f7c8f71 100644 --- a/datafusion/core/tests/sql/predicates.rs +++ b/datafusion/core/tests/sql/predicates.rs @@ -428,8 +428,101 @@ async fn csv_in_set_test() -> Result<()> { } #[tokio::test] -#[ignore] -// https://github.com/apache/arrow-datafusion/issues/3936 +async fn in_list_string_dictionaries() -> Result<()> { + // let input = vec![Some("foo"), Some("bar"), None, Some("fazzz")] + let input = vec![Some("foo"), Some("bar"), Some("fazzz")] + .into_iter() + .collect::>(); + + let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap(); + + let ctx = SessionContext::new(); + ctx.register_batch("test", batch)?; + + let sql = "SELECT * FROM test WHERE c1 IN ('Bar')"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec!["++", "++"]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT * FROM test WHERE c1 IN ('foo')"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec!["+-----+", "| c1 |", "+-----+", "| foo |", "+-----+"]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT * FROM test WHERE c1 IN ('bar', 'foo')"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----+", "| c1 |", "+-----+", "| foo |", "| bar |", "+-----+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT * FROM test WHERE c1 IN ('Bar', 'foo')"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec!["+-----+", "| c1 |", "+-----+", "| foo |", "+-----+"]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazzz')"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-------+", + "| c1 |", + "+-------+", + "| foo |", + "| fazzz |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn in_list_string_dictionaries_with_null() -> Result<()> { + let input = vec![Some("foo"), Some("bar"), None, Some("fazzz")] + .into_iter() + .collect::>(); + + let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap(); + + let ctx = SessionContext::new(); + ctx.register_batch("test", batch)?; + + let sql = "SELECT * FROM test WHERE c1 IN ('Bar')"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec!["++", "++"]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT * FROM test WHERE c1 IN ('foo')"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec!["+-----+", "| c1 |", "+-----+", "| foo |", "+-----+"]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT * FROM test WHERE c1 IN ('bar', 'foo')"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----+", "| c1 |", "+-----+", "| foo |", "| bar |", "+-----+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT * FROM test WHERE c1 IN ('Bar', 'foo')"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec!["+-----+", "| c1 |", "+-----+", "| foo |", "+-----+"]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazzz')"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-------+", + "| c1 |", + "+-------+", + "| foo |", + "| fazzz |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] async fn in_set_string_dictionaries() -> Result<()> { let input = vec![Some("foo"), Some("bar"), None, Some("fazzz")] .into_iter() @@ -440,7 +533,7 @@ async fn in_set_string_dictionaries() -> Result<()> { let ctx = SessionContext::new(); ctx.register_batch("test", batch)?; - let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazz')"; + let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazzz')"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 6fc6f4176a46..0b2d72dd3d27 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -54,6 +54,7 @@ hashbrown = { version = "0.12", features = ["raw"] } itertools = { version = "0.10", features = ["use_std"] } lazy_static = { version = "^1.4.0" } md-5 = { version = "^0.10.0", optional = true } +num-traits = { version = "0.2", default-features = false } ordered-float = "3.0" paste = "^1.0" rand = "0.8" diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 9406b42ee6f8..26503bec7241 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -27,10 +27,11 @@ use crate::physical_expr::down_cast_any_ref; use crate::utils::expr_list_eq_any_order; use crate::PhysicalExpr; use arrow::array::*; +use arrow::compute::take; use arrow::datatypes::*; -use arrow::downcast_primitive_array; use arrow::record_batch::RecordBatch; use arrow::util::bit_iterator::BitIndexIterator; +use arrow::{downcast_dictionary_array, downcast_primitive_array}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; use hashbrown::hash_map::RawEntryMut; @@ -57,7 +58,7 @@ impl Debug for InListExpr { /// A type-erased container of array elements trait Set: Send + Sync { - fn contains(&self, v: &dyn Array, negated: bool) -> BooleanArray; + fn contains(&self, v: &dyn Array, negated: bool) -> Result; } struct ArrayHashSet { @@ -92,13 +93,22 @@ where for<'a> &'a T: ArrayAccessor, for<'a> <&'a T as ArrayAccessor>::Item: PartialEq + HashValue, { - fn contains(&self, v: &dyn Array, negated: bool) -> BooleanArray { + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + downcast_dictionary_array! { + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(BooleanArray::from(result.data().clone())) + } + _ => {} + } + let v = v.as_any().downcast_ref::().unwrap(); let in_data = self.array.data(); let in_array = &self.array; let has_nulls = in_data.null_count() != 0; - ArrayIter::new(v) + Ok(ArrayIter::new(v) .map(|v| { v.and_then(|v| { let hash = v.hash_one(&self.hash_set.state); @@ -116,7 +126,7 @@ where } }) }) - .collect() + .collect()) } } @@ -188,10 +198,12 @@ fn make_set(array: &dyn Array) -> Result> { let array = as_generic_binary_array::(array); Box::new(ArraySet::new(array, make_hash_set(array))) } + DataType::Dictionary(_, _) => unreachable!("dictionary should have been flattened"), d => return Err(DataFusionError::NotImplemented(format!("DataType::{} not supported in InList", d))) }) } +/// Evaluates the list of expressions into an array, flattening any dictionaries fn evaluate_list( list: &[Arc], batch: &RecordBatch, @@ -203,6 +215,8 @@ fn evaluate_list( ColumnarValue::Array(_) => Err(DataFusionError::Execution( "InList expression must evaluate to a scalar".to_string(), )), + // Flatten dictionary values + ColumnarValue::Scalar(ScalarValue::Dictionary(_, v)) => Ok(*v), ColumnarValue::Scalar(s) => Ok(s), }) }) @@ -286,10 +300,10 @@ impl PhysicalExpr for InListExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let value = self.expr.evaluate(batch)?.into_array(1); let r = match &self.static_filter { - Some(f) => f.contains(value.as_ref(), self.negated), + Some(f) => f.contains(value.as_ref(), self.negated)?, None => { let list = evaluate_list(&self.list, batch)?; - make_set(list.as_ref())?.contains(value.as_ref(), self.negated) + make_set(list.as_ref())?.contains(value.as_ref(), self.negated)? } }; Ok(ColumnarValue::Array(Arc::new(r))) @@ -947,7 +961,7 @@ mod tests { let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); let array = Int64Array::from(vec![1, 2, 3, 4]); - let r = result.contains(&array, false); + let r = result.contains(&array, false).unwrap(); assert_eq!(r, BooleanArray::from(vec![true, true, true, false])); try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();