Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support array flatten sql function #7239

Merged
merged 5 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 40 additions & 18 deletions datafusion/core/tests/sqllogictests/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ AS VALUES
(NULL, NULL, NULL, NULL)
;

statement ok
CREATE TABLE flatten_table
AS VALUES
(make_array([1], [2], [3]), make_array([[1, 2, 3]], [[4, 5]], [[6]]), make_array([[[1]]], [[[2, 3]]]), make_array([1.0], [2.1, 2.2], [3.2, 3.3, 3.4])),
(make_array([1, 2], [3, 4], [5, 6]), make_array([[8]]), make_array([[[1,2]]], [[[3]]]), make_array([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]))
;

statement ok
CREATE TABLE array_has_table_1D
AS VALUES
Expand Down Expand Up @@ -614,10 +621,8 @@ select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h'
NULL NULL

# array_element scalar function #4 (with NULL)
query error
query error
select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL);
----
NULL NULL

# array_element scalar function #5 (with negative index)
query IT
Expand Down Expand Up @@ -724,16 +729,12 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h',
[1, 2, 3, 4] [h, e, l]

# array_slice scalar function #8 (with NULL and positive number)
query error
query error
select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3);
----
[1, 2, 3, 4] [h, e, l]

# array_slice scalar function #9 (with positive number and NULL)
query error
query error
select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL);
----
[2, 3, 4, 5] [l, l, o]

# array_slice scalar function #10 (with zero-zero)
query ??
Expand All @@ -742,10 +743,8 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), array_slice(make_array('h',
[] []

# array_slice scalar function #11 (with NULL-NULL)
query error
query error
select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL);
----
[] []

# array_slice scalar function #12 (with zero and negative number)
query ??
Expand All @@ -754,16 +753,12 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h'
[1] [h, e]

# array_slice scalar function #13 (with negative number and NULL)
query error
query error
select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL);
----
[2, 3, 4, 5] [l, l, o]

# array_slice scalar function #14 (with NULL and negative number)
query error
query error
select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3);
----
[1] [h, e]

# array_slice scalar function #15 (with negative indexes)
query ??
Expand Down Expand Up @@ -2319,6 +2314,30 @@ select array_concat(column1, [7]) from arrays_values_v2;
[11, 12, 7]
[7]

# flatten
query ???
select flatten(make_array(1, 2, 1, 3, 2)),
flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))),
flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]));
----
[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4]

query ????
select column1, column2, column3, column4 from flatten_table;
----
[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], [2.1, 2.2], [3.2, 3.3, 3.4]]
[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

query ????
select flatten(column1),
flatten(column2),
flatten(column3),
flatten(column4)
from flatten_table;
----
[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

### Delete tables

statement ok
Expand Down Expand Up @@ -2371,3 +2390,6 @@ drop table arrays_with_repeating_elements;

statement ok
drop table nested_arrays_with_repeating_elements;

statement ok
drop table flatten_table;
25 changes: 23 additions & 2 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ pub enum BuiltinScalarFunction {
Cardinality,
/// construct an array from columns
MakeArray,
/// Flatten
Flatten,

// struct functions
/// struct
Expand Down Expand Up @@ -366,6 +368,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplace => Volatility::Immutable,
BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable,
BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable,
BuiltinScalarFunction::Flatten => Volatility::Immutable,
BuiltinScalarFunction::ArraySlice => Volatility::Immutable,
BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
Expand Down Expand Up @@ -499,6 +502,22 @@ impl BuiltinScalarFunction {
// the return type of the built in function.
// Some built-in functions' return type depends on the incoming type.
match self {
BuiltinScalarFunction::Flatten => {
fn get_base_type(data_type: &DataType) -> Result<DataType> {
match data_type {
DataType::List(field) => match field.data_type() {
DataType::List(_) => get_base_type(field.data_type()),
_ => Ok(data_type.to_owned()),
},
_ => Err(DataFusionError::Internal(
"Not reachable, data_type should be List".to_string(),
)),
}
}

let data_type = get_base_type(&input_expr_types[0])?;
Ok(data_type)
}
BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayConcat => {
let mut expr_type = Null;
Expand Down Expand Up @@ -817,11 +836,12 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayConcat => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayHasAll
| BuiltinScalarFunction::ArrayHasAny
| BuiltinScalarFunction::ArrayHas => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayLength => {
Signature::variadic_any(self.volatility())
}
Expand Down Expand Up @@ -1305,6 +1325,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
"list_element",
"list_extract",
],
BuiltinScalarFunction::Flatten => &["flatten"],
BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"],
BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"],
BuiltinScalarFunction::ArrayHas => {
Expand Down
6 changes: 6 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,12 @@ scalar_expr!(
first_array second_array,
"Returns true if at least one element of the second array appears in the first array; otherwise, it returns false."
);
scalar_expr!(
Flatten,
flatten,
array,
"flattens an array of arrays into a single array."
);
scalar_expr!(
ArrayDims,
array_dims,
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;

fun.return_type(&data_types)
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
Expand Down
47 changes: 47 additions & 0 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1738,6 +1738,53 @@ pub fn cardinality(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}

// Create new offsets that are euqiavlent to `flatten` the array.
fn get_offsets_for_flatten(
offsets: OffsetBuffer<i32>,
indexes: OffsetBuffer<i32>,
) -> OffsetBuffer<i32> {
let buffer = offsets.into_inner();
let offsets: Vec<i32> = indexes.iter().map(|i| buffer[*i as usize]).collect();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very nice implementation 👌

OffsetBuffer::new(offsets.into())
}

fn flatten_internal(
array: &dyn Array,
indexes: Option<OffsetBuffer<i32>>,
) -> Result<ListArray> {
let list_arr = as_list_array(array)?;
let (field, offsets, values, nulls) = list_arr.clone().into_parts();
let data_type = field.data_type();

match data_type {
// Recursively get the base offsets for flattened array
DataType::List(_) => {
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
flatten_internal(&values, Some(offsets))
} else {
flatten_internal(&values, Some(offsets))
}
}
// Reach the base level, create a new list array
_ => {
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
let list_arr = ListArray::new(field, offsets, values, nulls);
Ok(list_arr)
} else {
Ok(list_arr.clone())
}
}
}
}

/// Flatten SQL function
pub fn flatten(args: &[ArrayRef]) -> Result<ArrayRef> {
let flattened_array = flatten_internal(&args[0], None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
}

/// Array_length SQL function
pub fn array_length(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_list_array(&args[0])?;
Expand Down
4 changes: 4 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,10 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayLength => {
Arc::new(|args| make_scalar_function(array_expressions::array_length)(args))
}
BuiltinScalarFunction::Flatten => {
Arc::new(|args| make_scalar_function(array_expressions::flatten)(args))
}

BuiltinScalarFunction::ArrayNdims => {
Arc::new(|args| make_scalar_function(array_expressions::array_ndims)(args))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ enum ScalarFunction {
ArrayRemoveAll = 109;
ArrayReplaceAll = 110;
Nanvl = 111;
Flatten = 112;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::ArrayHas => Self::ArrayHas,
ScalarFunction::ArrayDims => Self::ArrayDims,
ScalarFunction::ArrayElement => Self::ArrayElement,
ScalarFunction::Flatten => Self::Flatten,
ScalarFunction::ArrayLength => Self::ArrayLength,
ScalarFunction::ArrayNdims => Self::ArrayNdims,
ScalarFunction::ArrayPosition => Self::ArrayPosition,
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::ArrayHas => Self::ArrayHas,
BuiltinScalarFunction::ArrayDims => Self::ArrayDims,
BuiltinScalarFunction::ArrayElement => Self::ArrayElement,
BuiltinScalarFunction::Flatten => Self::Flatten,
BuiltinScalarFunction::ArrayLength => Self::ArrayLength,
BuiltinScalarFunction::ArrayNdims => Self::ArrayNdims,
BuiltinScalarFunction::ArrayPosition => Self::ArrayPosition,
Expand Down
1 change: 1 addition & 0 deletions docs/source/user-guide/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ Unlike to some databases the math functions in Datafusion works the same way as
| array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` |
| array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` |
| array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` |
| flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` |
| array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` |
| array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` |
| array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` |
Expand Down
18 changes: 18 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -1685,6 +1685,24 @@ array_fill(element, array)
Can be a constant, column, or function, and any combination of array operators.
- **element**: Element to copy to the array.

### `flatten`

Converts an array of arrays to a flat array

- Applies to any depth of nested arrays
- Does not change arrays that are already flat

The flattened array contains all the elements from all source arrays.

#### Arguments

- **array**: Array expression
Can be a constant, column, or function, and any combination of array operators.

```
flatten(array)
```

### `array_indexof`

_Alias of [array_position](#array_position)._
Expand Down