diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index 6ebde09ee811..3c94b5bf411a 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -357,3 +357,39 @@ query ? select make_array(x, y) from foo2; ---- [1.0, 1] + +# array_contains scalar function #1 +query BBB rowsort +select array_contains(make_array(1, 2, 3), make_array(1, 1, 2, 3)), array_contains([1, 2, 3], [1, 1, 2]), array_contains([1, 2, 3], [2, 1, 3, 1]); +---- +true true true + +# array_contains scalar function #2 +query BB rowsort +select array_contains([[1, 2], [3, 4]], [[1, 2], [3, 4], [1, 3]]), array_contains([[[1], [2]], [[3], [4]]], [1, 2, 2, 3, 4]); +---- +true true + +# array_contains scalar function #3 +query BBB rowsort +select array_contains(make_array(1, 2, 3), make_array(1, 2, 3, 4)), array_contains([1, 2, 3], [1, 1, 4]), array_contains([1, 2, 3], [2, 1, 3, 4]); +---- +false false false + +# array_contains scalar function #4 +query BB rowsort +select array_contains([[1, 2], [3, 4]], [[1, 2], [3, 4], [1, 5]]), array_contains([[[1], [2]], [[3], [4]]], [1, 2, 2, 3, 5]); +---- +false false + +# array_contains scalar function #5 +query BB rowsort +select array_contains([true, true, false, true, false], [true, false, false]), array_contains([true, false, true], [true, true]); +---- +true true + +# array_contains scalar function #6 +query BB rowsort +select array_contains(make_array(true, true, true), make_array(false, false)), array_contains([false, false, false], [true, true]); +---- +false false diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index ef8892ee4953..677743f5b320 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -113,6 +113,8 @@ pub enum BuiltinScalarFunction { ArrayAppend, /// array_concat ArrayConcat, + /// array_contains + 
ArrayContains, /// array_dims ArrayDims, /// array_fill @@ -319,6 +321,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::ArrayAppend => Volatility::Immutable, BuiltinScalarFunction::ArrayConcat => Volatility::Immutable, + BuiltinScalarFunction::ArrayContains => Volatility::Immutable, BuiltinScalarFunction::ArrayDims => Volatility::Immutable, BuiltinScalarFunction::ArrayFill => Volatility::Immutable, BuiltinScalarFunction::ArrayLength => Volatility::Immutable, @@ -460,6 +463,7 @@ impl BuiltinScalarFunction { "The {self} function can only accept fixed size list as the args." ))), }, + BuiltinScalarFunction::ArrayContains => Ok(Boolean), BuiltinScalarFunction::ArrayDims => Ok(UInt8), BuiltinScalarFunction::ArrayFill => Ok(List(Arc::new(Field::new( "item", @@ -741,6 +745,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayConcat => { Signature::variadic_any(self.volatility()) } + BuiltinScalarFunction::ArrayContains => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayFill => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayLength => { @@ -1166,6 +1171,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { // array functions BuiltinScalarFunction::ArrayAppend => &["array_append"], BuiltinScalarFunction::ArrayConcat => &["array_concat"], + BuiltinScalarFunction::ArrayContains => &["array_contains"], BuiltinScalarFunction::ArrayDims => &["array_dims"], BuiltinScalarFunction::ArrayFill => &["array_fill"], BuiltinScalarFunction::ArrayLength => &["array_length"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index ef782b319cd7..a45cf0febaa0 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -530,6 +530,13 @@ scalar_expr!( "appends an element to the end of an array." 
); nary_scalar_expr!(ArrayConcat, array_concat, "concatenates arrays."); +scalar_expr!( + ArrayContains, + array_contains, + first_array second_array, +"returns true, if each element of the second array appe + aring in the first array, otherwise false." +); scalar_expr!( ArrayDims, array_dims, diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index ffb1eddf57f9..a4b0327d8d36 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -26,6 +26,7 @@ use datafusion_common::cast::as_list_array; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; +use itertools::Itertools; use std::sync::Arc; macro_rules! downcast_vec { @@ -1070,6 +1071,70 @@ pub fn array_ndims(args: &[ColumnarValue]) -> Result { ])))) } +macro_rules! contains { + ($FIRST_ARRAY:expr, $SECOND_ARRAY:expr, $ARRAY_TYPE:ident) => {{ + let first_array = downcast_arg!($FIRST_ARRAY, $ARRAY_TYPE); + let second_array = downcast_arg!($SECOND_ARRAY, $ARRAY_TYPE); + let mut res = true; + for x in second_array.values().iter().dedup() { + if !first_array.values().contains(x) { + res = false; + } + } + + res + }}; +} + +/// Array_contains SQL function +pub fn array_contains(args: &[ArrayRef]) -> Result { + fn concat_inner_lists(arg: ArrayRef) -> Result { + match arg.data_type() { + DataType::List(field) => match field.data_type() { + DataType::List(..) => { + concat_inner_lists(array_concat(&[as_list_array(&arg)? + .values() + .clone()])?) + } + _ => Ok(as_list_array(&arg)?.values().clone()), + }, + data_type => Err(DataFusionError::NotImplemented(format!( + "Array is not type '{data_type:?}'." 
+ ))), + } + } + + let concat_first_array = concat_inner_lists(args[0].clone())?.clone(); + let concat_second_array = concat_inner_lists(args[1].clone())?.clone(); + + let res = match (concat_first_array.data_type(), concat_second_array.data_type()) { + (DataType::Utf8, DataType::Utf8) => contains!(concat_first_array, concat_second_array, StringArray), + (DataType::LargeUtf8, DataType::LargeUtf8) => contains!(concat_first_array, concat_second_array, LargeStringArray), + (DataType::Boolean, DataType::Boolean) => { + let first_array = downcast_arg!(concat_first_array, BooleanArray); + let second_array = downcast_arg!(concat_second_array, BooleanArray); + compute::bool_or(first_array) == compute::bool_or(second_array) + } + (DataType::Float32, DataType::Float32) => contains!(concat_first_array, concat_second_array, Float32Array), + (DataType::Float64, DataType::Float64) => contains!(concat_first_array, concat_second_array, Float64Array), + (DataType::Int8, DataType::Int8) => contains!(concat_first_array, concat_second_array, Int8Array), + (DataType::Int16, DataType::Int16) => contains!(concat_first_array, concat_second_array, Int16Array), + (DataType::Int32, DataType::Int32) => contains!(concat_first_array, concat_second_array, Int32Array), + (DataType::Int64, DataType::Int64) => contains!(concat_first_array, concat_second_array, Int64Array), + (DataType::UInt8, DataType::UInt8) => contains!(concat_first_array, concat_second_array, UInt8Array), + (DataType::UInt16, DataType::UInt16) => contains!(concat_first_array, concat_second_array, UInt16Array), + (DataType::UInt32, DataType::UInt32) => contains!(concat_first_array, concat_second_array, UInt32Array), + (DataType::UInt64, DataType::UInt64) => contains!(concat_first_array, concat_second_array, UInt64Array), + (first_array_data_type, second_array_data_type) => { + return Err(DataFusionError::NotImplemented(format!( + "Array_contains is not implemented for types '{first_array_data_type:?}' and 
'{second_array_data_type:?}'." + ))) + } + }; + + Ok(Arc::new(BooleanArray::from(vec![res]))) +} + #[cfg(test)] mod tests { use super::*; @@ -1588,7 +1653,7 @@ mod tests { #[test] fn test_array_ndims() { - // array_ndims([1, 2]) = 1 + // array_ndims([1, 2, 3, 4]) = 1 let list_array = return_array(); let array = array_ndims(&[list_array]) @@ -1602,7 +1667,7 @@ mod tests { #[test] fn test_nested_array_ndims() { - // array_ndims([[1, 2], [3, 4]]) = 2 + // array_ndims([[1, 2, 3, 4], [5, 6, 7, 8]]) = 2 let list_array = return_nested_array(); let array = array_ndims(&[list_array]) @@ -1614,6 +1679,63 @@ mod tests { assert_eq!(result, &UInt8Array::from(vec![2])); } + #[test] + fn test_array_contains() { + // array_contains([1, 2, 3, 4], array_append([1, 2, 3, 4], 3)) = t + let first_array = return_array().into_array(1); + let second_array = array_append(&[ + first_array.clone(), + Arc::new(Int64Array::from(vec![Some(3)])), + ]) + .expect("failed to initialize function array_contains"); + + let arr = array_contains(&[first_array.clone(), second_array]) + .expect("failed to initialize function array_contains"); + let result = as_boolean_array(&arr); + + assert_eq!(result, &BooleanArray::from(vec![true])); + + // array_contains([1, 2, 3, 4], array_append([1, 2, 3, 4], 5)) = f + let second_array = array_append(&[ + first_array.clone(), + Arc::new(Int64Array::from(vec![Some(5)])), + ]) + .expect("failed to initialize function array_contains"); + + let arr = array_contains(&[first_array.clone(), second_array]) + .expect("failed to initialize function array_contains"); + let result = as_boolean_array(&arr); + + assert_eq!(result, &BooleanArray::from(vec![false])); + } + + #[test] + fn test_nested_array_contains() { + // array_contains([[1, 2, 3, 4], [5, 6, 7, 8]], array_append([1, 2, 3, 4], 3)) = t + let first_array = return_nested_array().into_array(1); + let array = return_array().into_array(1); + let second_array = + array_append(&[array.clone(), 
Arc::new(Int64Array::from(vec![Some(3)]))]) + .expect("failed to initialize function array_contains"); + + let arr = array_contains(&[first_array.clone(), second_array]) + .expect("failed to initialize function array_contains"); + let result = as_boolean_array(&arr); + + assert_eq!(result, &BooleanArray::from(vec![true])); + + // array_contains([[1, 2, 3, 4], [5, 6, 7, 8]], array_append([1, 2, 3, 4], 9)) = f + let second_array = + array_append(&[array.clone(), Arc::new(Int64Array::from(vec![Some(9)]))]) + .expect("failed to initialize function array_contains"); + + let arr = array_contains(&[first_array.clone(), second_array]) + .expect("failed to initialize function array_contains"); + let result = as_boolean_array(&arr); + + assert_eq!(result, &BooleanArray::from(vec![false])); + } + fn return_array() -> ColumnarValue { let args = [ ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 37dd492e9e18..016e8bf766f4 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -391,6 +391,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayConcat => { Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args)) } + BuiltinScalarFunction::ArrayContains => { + Arc::new(|args| make_scalar_function(array_expressions::array_contains)(args)) + } BuiltinScalarFunction::ArrayDims => Arc::new(array_expressions::array_dims), BuiltinScalarFunction::ArrayFill => Arc::new(array_expressions::array_fill), BuiltinScalarFunction::ArrayLength => Arc::new(array_expressions::array_length), diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 7fddd31a6ba5..4cc80c207ca4 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -563,6 +563,7 @@ enum ScalarFunction { ArrayToString = 97; Cardinality = 98; TrimArray = 99; + 
ArrayContains = 100; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 4c1bab5e397c..42397e3da239 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -17883,6 +17883,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayToString => "ArrayToString", Self::Cardinality => "Cardinality", Self::TrimArray => "TrimArray", + Self::ArrayContains => "ArrayContains", }; serializer.serialize_str(variant) } @@ -17994,6 +17995,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayToString", "Cardinality", "TrimArray", + "ArrayContains", ]; struct GeneratedVisitor; @@ -18136,6 +18138,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayToString" => Ok(ScalarFunction::ArrayToString), "Cardinality" => Ok(ScalarFunction::Cardinality), "TrimArray" => Ok(ScalarFunction::TrimArray), + "ArrayContains" => Ok(ScalarFunction::ArrayContains), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 8dfc209477ef..31086deead1a 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2224,6 +2224,7 @@ pub enum ScalarFunction { ArrayToString = 97, Cardinality = 98, TrimArray = 99, + ArrayContains = 100, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2332,6 +2333,7 @@ impl ScalarFunction { ScalarFunction::ArrayToString => "ArrayToString", ScalarFunction::Cardinality => "Cardinality", ScalarFunction::TrimArray => "TrimArray", + ScalarFunction::ArrayContains => "ArrayContains", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -2437,6 +2439,7 @@ impl ScalarFunction { "ArrayToString" => Some(Self::ArrayToString), "Cardinality" => Some(Self::Cardinality), "TrimArray" => Some(Self::TrimArray), + "ArrayContains" => Some(Self::ArrayContains), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ab2985f448a8..5fabf17e327e 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -36,12 +36,12 @@ use datafusion_common::{ }; use datafusion_expr::expr::Placeholder; use datafusion_expr::{ - abs, acos, acosh, array, array_append, array_concat, array_dims, array_fill, - array_length, array_ndims, array_position, array_positions, array_prepend, - array_remove, array_replace, array_to_string, ascii, asin, asinh, atan, atan2, atanh, - bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, - concat_expr, concat_ws_expr, cos, cosh, date_bin, date_part, date_trunc, degrees, - digest, exp, + abs, acos, acosh, array, array_append, array_concat, array_contains, array_dims, + array_fill, array_length, array_ndims, array_position, array_positions, + array_prepend, array_remove, array_replace, array_to_string, ascii, asin, asinh, + atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, + chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, date_bin, date_part, + date_trunc, degrees, digest, exp, expr::{self, InList, Sort, WindowFunction}, factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, @@ -450,6 +450,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ToTimestamp => Self::ToTimestamp, ScalarFunction::ArrayAppend => Self::ArrayAppend, ScalarFunction::ArrayConcat => Self::ArrayConcat, + ScalarFunction::ArrayContains => Self::ArrayContains, ScalarFunction::ArrayDims => Self::ArrayDims, ScalarFunction::ArrayFill => 
Self::ArrayFill, ScalarFunction::ArrayLength => Self::ArrayLength, @@ -1192,6 +1193,10 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), + ScalarFunction::ArrayContains => Ok(array_contains( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ArrayFill => Ok(array_fill( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b3e5bd0fa6c2..bf233c229f7f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1344,6 +1344,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp, BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend, BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat, + BuiltinScalarFunction::ArrayContains => Self::ArrayContains, BuiltinScalarFunction::ArrayDims => Self::ArrayDims, BuiltinScalarFunction::ArrayFill => Self::ArrayFill, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 07f5923a6a34..b5d4cfa0ac9e 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -179,23 +179,24 @@ Unlike to some databases the math functions in Datafusion works the same way as ## Array Expressions -| Function | Notes | -| ------------------------------------ | --------------------------------------------------------------- | -| array_append(array, element) | Appends an element to the end of an array. | -| array_concat(array[, ..., array_n]) | Concatenates arrays. | -| array_dims(array) | Returns an array of the array's dimensions. | -| array_fill(element, array) | Returns an array filled with copies of the given value. 
| -| array_length(array, dimension) | Returns the length of the array dimension. | -| array_ndims(array) | Returns the number of dimensions of the array. | -| array_position(array, element) | Searches for an element in the array, returns first occurrence. | -| array_positions(array, element) | Searches for an element in the array, returns all occurrences. | -| array_prepend(array, element) | Prepends an element to the beginning of an array. | -| array_remove(array, element) | Removes all elements equal to the given value from the array. | -| array_replace(array, from, to) | Replaces a specified element with another specified element. | -| array_to_string(array, delimeter) | Converts each element to its text representation. | -| cardinality(array) | Returns the total number of elements in the array. | -| make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. | -| trim_array(array, n) | Removes the last n elements from the array. | +| Function | Notes | +| ----------------------------------------- | ------------------------------------------------------------------------------------------------ | +| array_append(array, element) | Appends an element to the end of an array. | +| array_concat(array[, ..., array_n]) | Concatenates arrays. | +| array_contains(first_array, second_array) | Returns true, if each element of the second array appears in the first array, otherwise false. | +| array_dims(array) | Returns an array of the array's dimensions. | +| array_fill(element, array) | Returns an array filled with copies of the given value. | +| array_length(array, dimension) | Returns the length of the array dimension. | +| array_ndims(array) | Returns the number of dimensions of the array. | +| array_position(array, element) | Searches for an element in the array, returns first occurrence. | +| array_positions(array, element) | Searches for an element in the array, returns all occurrences. 
| +| array_prepend(array, element) | Prepends an element to the beginning of an array. | +| array_remove(array, element) | Removes all elements equal to the given value from the array. | +| array_replace(array, from, to) | Replaces a specified element with another specified element. | +| array_to_string(array, delimeter) | Converts each element to its text representation. | +| cardinality(array) | Returns the total number of elements in the array. | +| make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. | +| trim_array(array, n) | Removes the last n elements from the array. | ## Regular Expressions diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index c24b6dc91af1..34999ddf168b 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1382,6 +1382,7 @@ from_unixtime(expression) - [array_append](#array_append) - [array_concat](#array_concat) +- [array_contains](#array_contains) - [array_dims](#array_dims) - [array_fill](#array_fill) - [array_length](#array_length) @@ -1424,6 +1425,21 @@ array_concat(array[, ..., array_n]) Can be a constant, column, or function, and any combination of array operators. - **array_n**: Subsequent array column or literal array to concatenate. +### `array_contains` + +Returns true, if each element of the second array appears in the first array, otherwise false. + +``` +array_contains(first_array, second_array) +``` + +#### Arguments + +- **first_array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **second_array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + ### `array_dims` Returns an array of the array's dimensions.