diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index 25e2e4b453a3..e4f425a95de8 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -112,6 +112,13 @@ AS VALUES (NULL, NULL, NULL, NULL) ; +statement ok +CREATE TABLE flatten_table +AS VALUES + (make_array([1], [2], [3]), make_array([[1, 2, 3]], [[4, 5]], [[6]]), make_array([[[1]]], [[[2, 3]]]), make_array([1.0], [2.1, 2.2], [3.2, 3.3, 3.4])), + (make_array([1, 2], [3, 4], [5, 6]), make_array([[8]]), make_array([[[1,2]]], [[[3]]]), make_array([1.0, 2.0], [3.0, 4.0], [5.0, 6.0])) +; + statement ok CREATE TABLE array_has_table_1D AS VALUES @@ -614,10 +621,8 @@ select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h' NULL NULL # array_element scalar function #4 (with NULL) -query error +query error select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL); ----- -NULL NULL # array_element scalar function #5 (with negative index) query IT @@ -724,16 +729,12 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h', [1, 2, 3, 4] [h, e, l] # array_slice scalar function #8 (with NULL and positive number) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); ----- -[1, 2, 3, 4] [h, e, l] # array_slice scalar function #9 (with positive number and NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); ----- -[2, 3, 4, 5] [l, l, o] # array_slice scalar function #10 (with zero-zero) query ?? 
@@ -742,10 +743,8 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), array_slice(make_array('h', [] [] # array_slice scalar function #11 (with NULL-NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL); ----- -[] [] # array_slice scalar function #12 (with zero and negative number) query ?? @@ -754,16 +753,12 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h' [1] [h, e] # array_slice scalar function #13 (with negative number and NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); ----- -[2, 3, 4, 5] [l, l, o] # array_slice scalar function #14 (with NULL and negative number) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3); ----- -[1] [h, e] # array_slice scalar function #15 (with negative indexes) query ?? @@ -2319,6 +2314,30 @@ select array_concat(column1, [7]) from arrays_values_v2; [11, 12, 7] [7] +# flatten +query ??? +select flatten(make_array(1, 2, 1, 3, 2)), + flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))), + flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]])); +---- +[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4] + +query ???? +select column1, column2, column3, column4 from flatten_table; +---- +[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], [2.1, 2.2], [3.2, 3.3, 3.4]] +[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + +query ???? 
+select flatten(column1), + flatten(column2), + flatten(column3), + flatten(column4) +from flatten_table; +---- +[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] +[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + ### Delete tables statement ok @@ -2371,3 +2390,6 @@ drop table arrays_with_repeating_elements; statement ok drop table nested_arrays_with_repeating_elements; + +statement ok +drop table flatten_table; diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 061d0689cd97..27f6c6be1c46 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -163,6 +163,8 @@ pub enum BuiltinScalarFunction { Cardinality, /// construct an array from columns MakeArray, + /// Flatten + Flatten, // struct functions /// struct @@ -366,6 +368,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplace => Volatility::Immutable, BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable, BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable, + BuiltinScalarFunction::Flatten => Volatility::Immutable, BuiltinScalarFunction::ArraySlice => Volatility::Immutable, BuiltinScalarFunction::ArrayToString => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, @@ -499,6 +502,22 @@ impl BuiltinScalarFunction { // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. 
match self { + BuiltinScalarFunction::Flatten => { + fn get_base_type(data_type: &DataType) -> Result { + match data_type { + DataType::List(field) => match field.data_type() { + DataType::List(_) => get_base_type(field.data_type()), + _ => Ok(data_type.to_owned()), + }, + _ => Err(DataFusionError::Internal( + "Not reachable, data_type should be List".to_string(), + )), + } + } + + let data_type = get_base_type(&input_expr_types[0])?; + Ok(data_type) + } BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayConcat => { let mut expr_type = Null; @@ -817,11 +836,12 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayConcat => { Signature::variadic_any(self.volatility()) } + BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), + BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayHasAll | BuiltinScalarFunction::ArrayHasAny | BuiltinScalarFunction::ArrayHas => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), - BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayLength => { Signature::variadic_any(self.volatility()) } @@ -1305,6 +1325,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { "list_element", "list_extract", ], + BuiltinScalarFunction::Flatten => &["flatten"], BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"], BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"], BuiltinScalarFunction::ArrayHas => { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index ef6ce8171153..47767c23b363 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -564,6 +564,12 @@ scalar_expr!( first_array second_array, "Returns true if at least one element 
of the second array appears in the first array; otherwise, it returns false." ); +scalar_expr!( + Flatten, + flatten, + array, + "flattens an array of arrays into a single array." +); scalar_expr!( ArrayDims, array_dims, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 89e3217f730b..103bcd9ea435 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -92,6 +92,7 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; + fun.return_type(&data_types) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index a223a6998a39..ece0b1796b82 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1738,6 +1738,53 @@ pub fn cardinality(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } +// Create new offsets that are equivalent to flattening the array.
+fn get_offsets_for_flatten( + offsets: OffsetBuffer, + indexes: OffsetBuffer, +) -> OffsetBuffer { + let buffer = offsets.into_inner(); + let offsets: Vec = indexes.iter().map(|i| buffer[*i as usize]).collect(); + OffsetBuffer::new(offsets.into()) +} + +fn flatten_internal( + array: &dyn Array, + indexes: Option>, +) -> Result { + let list_arr = as_list_array(array)?; + let (field, offsets, values, nulls) = list_arr.clone().into_parts(); + let data_type = field.data_type(); + + match data_type { + // Recursively get the base offsets for flattened array + DataType::List(_) => { + if let Some(indexes) = indexes { + let offsets = get_offsets_for_flatten(offsets, indexes); + flatten_internal(&values, Some(offsets)) + } else { + flatten_internal(&values, Some(offsets)) + } + } + // Reach the base level, create a new list array + _ => { + if let Some(indexes) = indexes { + let offsets = get_offsets_for_flatten(offsets, indexes); + let list_arr = ListArray::new(field, offsets, values, nulls); + Ok(list_arr) + } else { + Ok(list_arr.clone()) + } + } + } +} + +/// Flatten SQL function +pub fn flatten(args: &[ArrayRef]) -> Result { + let flattened_array = flatten_internal(&args[0], None)?; + Ok(Arc::new(flattened_array) as ArrayRef) +} + /// Array_length SQL function pub fn array_length(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index df76d55bfcaa..d1a5119ee8a3 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -437,6 +437,10 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayLength => { Arc::new(|args| make_scalar_function(array_expressions::array_length)(args)) } + BuiltinScalarFunction::Flatten => { + Arc::new(|args| make_scalar_function(array_expressions::flatten)(args)) + } + BuiltinScalarFunction::ArrayNdims => { Arc::new(|args| 
make_scalar_function(array_expressions::array_ndims)(args)) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 26353d95fbe4..367613fb4dbe 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -594,6 +594,7 @@ enum ScalarFunction { ArrayRemoveAll = 109; ArrayReplaceAll = 110; Nanvl = 111; + Flatten = 112; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index e87aa781b1f3..788dfa0feab5 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -18934,6 +18934,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayRemoveAll => "ArrayRemoveAll", Self::ArrayReplaceAll => "ArrayReplaceAll", Self::Nanvl => "Nanvl", + Self::Flatten => "Flatten", }; serializer.serialize_str(variant) } @@ -19057,6 +19058,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayRemoveAll", "ArrayReplaceAll", "Nanvl", + "Flatten", ]; struct GeneratedVisitor; @@ -19211,6 +19213,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayRemoveAll" => Ok(ScalarFunction::ArrayRemoveAll), "ArrayReplaceAll" => Ok(ScalarFunction::ArrayReplaceAll), "Nanvl" => Ok(ScalarFunction::Nanvl), + "Flatten" => Ok(ScalarFunction::Flatten), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2ec602eb36a7..eee27601822f 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2378,6 +2378,7 @@ pub enum ScalarFunction { ArrayRemoveAll = 109, ArrayReplaceAll = 110, Nanvl = 111, + Flatten = 112, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. 
@@ -2498,6 +2499,7 @@ impl ScalarFunction { ScalarFunction::ArrayRemoveAll => "ArrayRemoveAll", ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll", ScalarFunction::Nanvl => "Nanvl", + ScalarFunction::Flatten => "Flatten", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2615,6 +2617,7 @@ impl ScalarFunction { "ArrayRemoveAll" => Some(Self::ArrayRemoveAll), "ArrayReplaceAll" => Some(Self::ArrayReplaceAll), "Nanvl" => Some(Self::Nanvl), + "Flatten" => Some(Self::Flatten), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index bc865922d3af..c8b6ee79f984 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -457,6 +457,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayHas => Self::ArrayHas, ScalarFunction::ArrayDims => Self::ArrayDims, ScalarFunction::ArrayElement => Self::ArrayElement, + ScalarFunction::Flatten => Self::Flatten, ScalarFunction::ArrayLength => Self::ArrayLength, ScalarFunction::ArrayNdims => Self::ArrayNdims, ScalarFunction::ArrayPosition => Self::ArrayPosition, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 9c41dcf3c501..92a3a02ed7f8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1436,6 +1436,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayHas => Self::ArrayHas, BuiltinScalarFunction::ArrayDims => Self::ArrayDims, BuiltinScalarFunction::ArrayElement => Self::ArrayElement, + BuiltinScalarFunction::Flatten => Self::Flatten, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, BuiltinScalarFunction::ArrayNdims => Self::ArrayNdims, BuiltinScalarFunction::ArrayPosition => Self::ArrayPosition, diff --git a/docs/source/user-guide/expressions.md 
b/docs/source/user-guide/expressions.md index a04f43fd4b2b..88a5a73a6df3 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -188,6 +188,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` | | array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | | array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` | +| flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` | | array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | | array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | | array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index dec120db18c5..9bcf2ae0b09b 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1685,6 +1685,24 @@ array_fill(element, array) Can be a constant, column, or function, and any combination of array operators. - **element**: Element to copy to the array. +### `flatten` + +Converts an array of arrays to a flat array + +- Applies to any depth of nested arrays +- Does not change arrays that are already flat + +The flattened array contains all the elements from all source arrays. + +#### Arguments + +- **array**: Array expression + Can be a constant, column, or function, and any combination of array operators. 
+ +``` +flatten(array) +``` + ### `array_indexof` _Alias of [array_position](#array_position)._