Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: new concatenation operator for working with arrays #6615

Merged
merged 5 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,21 @@ query II rowsort
select array_ndims(array_fill(1, [1, 2, 3])), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]);
----
3 21

# array concatenate operator #1 (like array_concat scalar function)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend adding a test for `null handling in these arrays

like

select make_array(1, 2, 3) || make_array(4, null, 6);

and

select make_array(1, 2, 3) || null;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this feature will only be available after issue #6556 is resolved.

query ?? rowsort
select make_array(1, 2, 3) || make_array(4, 5, 6) || make_array(7, 8, 9), make_array([1], [2]) || make_array([3], [4]);
----
[1, 2, 3, 4, 5, 6, 7, 8, 9] [[1], [2], [3], [4]]

# array concatenate operator #2 (like array_append scalar function)
query ??? rowsort
select make_array(1, 2, 3) || 4, make_array(1.0, 2.0, 3.0) || 4.0, make_array('h', 'e', 'l', 'l') || 'o';
----
[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]

# array concatenate operator #3 (like array_prepend scalar function)
query ??? rowsort
select 1 || make_array(2, 3, 4), 1.0 || make_array(2.0, 3.0, 4.0), 'h' || make_array('e', 'l', 'l', 'o');
----
[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
6 changes: 6 additions & 0 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,8 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
string_concat_internal_coercion(from_type, &LargeUtf8)
}
// TODO: cast between array elements (#6558)
(List(_), from_type) | (from_type, List(_)) => Some(from_type.to_owned()),
_ => None,
})
}
Expand All @@ -697,6 +699,10 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
(LargeUtf8, Utf8) => Some(LargeUtf8),
(Utf8, LargeUtf8) => Some(LargeUtf8),
(LargeUtf8, LargeUtf8) => Some(LargeUtf8),
// TODO: cast between array elements (#6558)
(List(_), List(_)) => Some(lhs_type.clone()),
(List(_), _) => Some(lhs_type.clone()),
(_, List(_)) => Some(rhs_type.clone()),
_ => None,
}
}
Expand Down
223 changes: 67 additions & 156 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use arrow::buffer::Buffer;
use arrow::compute;
use arrow::datatypes::{DataType, Field};
use core::any::type_name;
use datafusion_common::cast::as_list_array;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
Expand Down Expand Up @@ -166,44 +167,28 @@ macro_rules! append {
let child_array =
downcast_arg!(downcast_arg!($ARRAY, ListArray).values(), $ARRAY_TYPE);
let element = downcast_arg!($ELEMENT, $ARRAY_TYPE);
let concat = compute::concat(&[child_array, element])?;
let cat = compute::concat(&[child_array, element])?;
let mut scalars = vec![];
for i in 0..concat.len() {
scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(
&concat, i,
)?));
for i in 0..cat.len() {
scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(&cat, i)?));
}
scalars
}};
}

/// Array_append SQL function
pub fn array_append(args: &[ColumnarValue]) -> Result<ColumnarValue> {
pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return Err(DataFusionError::Internal(format!(
"Array_append function requires two arguments, got {}",
args.len()
)));
}

let arr = match &args[0] {
ColumnarValue::Scalar(scalar) => scalar.to_array().clone(),
ColumnarValue::Array(arr) => arr.clone(),
};

let element = match &args[1] {
ColumnarValue::Scalar(scalar) => scalar.to_array().clone(),
_ => {
return Err(DataFusionError::Internal(
"Array_append function requires scalar element".to_string(),
))
}
};
let arr = as_list_array(&args[0])?;
let element = &args[1];

let data_type = arr.data_type();
let arrays = match data_type {
DataType::List(field) => {
match (field.data_type(), element.data_type()) {
let scalars = match (arr.value_type(), element.data_type()) {
(DataType::Utf8, DataType::Utf8) => append!(arr, element, StringArray),
(DataType::LargeUtf8, DataType::LargeUtf8) => append!(arr, element, LargeStringArray),
(DataType::Boolean, DataType::Boolean) => append!(arr, element, BooleanArray),
Expand All @@ -222,61 +207,38 @@ pub fn array_append(args: &[ColumnarValue]) -> Result<ColumnarValue> {
"Array_append is not implemented for types '{array_data_type:?}' and '{element_data_type:?}'."
)))
}
}
}
data_type => {
return Err(DataFusionError::Internal(format!(
"Array is not type '{data_type:?}'."
)))
}
};

array(arrays.as_slice())
Ok(array(scalars.as_slice())?.into_array(1))
}

macro_rules! prepend {
($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{
let child_array =
downcast_arg!(downcast_arg!($ARRAY, ListArray).values(), $ARRAY_TYPE);
let element = downcast_arg!($ELEMENT, $ARRAY_TYPE);
let concat = compute::concat(&[element, child_array])?;
let cat = compute::concat(&[element, child_array])?;
let mut scalars = vec![];
for i in 0..concat.len() {
scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(
&concat, i,
)?));
for i in 0..cat.len() {
scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(&cat, i)?));
}
scalars
}};
}

/// Array_prepend SQL function
pub fn array_prepend(args: &[ColumnarValue]) -> Result<ColumnarValue> {
pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return Err(DataFusionError::Internal(format!(
"Array_prepend function requires two arguments, got {}",
args.len()
)));
}

let element = match &args[0] {
ColumnarValue::Scalar(scalar) => scalar.to_array().clone(),
_ => {
return Err(DataFusionError::Internal(
"Array_prepend function requires scalar element".to_string(),
))
}
};

let arr = match &args[1] {
ColumnarValue::Scalar(scalar) => scalar.to_array().clone(),
ColumnarValue::Array(arr) => arr.clone(),
};
let element = &args[0];
let arr = as_list_array(&args[1])?;

let data_type = arr.data_type();
let arrays = match data_type {
DataType::List(field) => {
match (field.data_type(), element.data_type()) {
let scalars = match (arr.value_type(), element.data_type()) {
(DataType::Utf8, DataType::Utf8) => prepend!(arr, element, StringArray),
(DataType::LargeUtf8, DataType::LargeUtf8) => prepend!(arr, element, LargeStringArray),
(DataType::Boolean, DataType::Boolean) => prepend!(arr, element, BooleanArray),
Expand All @@ -295,57 +257,33 @@ pub fn array_prepend(args: &[ColumnarValue]) -> Result<ColumnarValue> {
"Array_prepend is not implemented for types '{array_data_type:?}' and '{element_data_type:?}'."
)))
}
}
}
data_type => {
return Err(DataFusionError::Internal(format!(
"Array is not type '{data_type:?}'."
)))
}
};

array(arrays.as_slice())
Ok(array(scalars.as_slice())?.into_array(1))
}

/// Array_concat/Array_cat SQL function
pub fn array_concat(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let arrays: Vec<ArrayRef> = args
.iter()
.map(|x| match x {
ColumnarValue::Array(array) => array.clone(),
ColumnarValue::Scalar(scalar) => scalar.to_array().clone(),
})
.collect();
let data_type = arrays[0].data_type();
match data_type {
DataType::List(..) => {
let list_arrays =
downcast_vec!(arrays, ListArray).collect::<Result<Vec<&ListArray>>>()?;
let len: usize = list_arrays.iter().map(|a| a.values().len()).sum();
let capacity = Capacities::Array(list_arrays.iter().map(|a| a.len()).sum());
let array_data: Vec<_> =
list_arrays.iter().map(|a| a.to_data()).collect::<Vec<_>>();
let array_data = array_data.iter().collect();
let mut mutable =
MutableArrayData::with_capacities(array_data, false, capacity);

for (i, a) in list_arrays.iter().enumerate() {
mutable.extend(i, 0, a.len())
}
pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_arrays =
downcast_vec!(args, ListArray).collect::<Result<Vec<&ListArray>>>()?;
let len: usize = list_arrays.iter().map(|a| a.values().len()).sum();
let capacity = Capacities::Array(list_arrays.iter().map(|a| a.len()).sum());
let array_data: Vec<_> = list_arrays.iter().map(|a| a.to_data()).collect::<Vec<_>>();
let array_data = array_data.iter().collect();
let mut mutable = MutableArrayData::with_capacities(array_data, false, capacity);

for (i, a) in list_arrays.iter().enumerate() {
mutable.extend(i, 0, a.len())
}

let builder = mutable.into_builder();
let list = builder
.len(1)
.buffers(vec![Buffer::from_slice_ref([0, len as i32])])
.build()
.unwrap();
let builder = mutable.into_builder();
let list = builder
.len(1)
.buffers(vec![Buffer::from_slice_ref([0, len as i32])])
.build()
.unwrap();

return Ok(ColumnarValue::Array(Arc::new(make_array(list))));
}
_ => Err(DataFusionError::NotImplemented(format!(
"Array is not type '{data_type:?}'."
))),
}
return Ok(Arc::new(make_array(list)));
}

macro_rules! fill {
Expand Down Expand Up @@ -1096,6 +1034,7 @@ pub fn array_ndims(args: &[ColumnarValue]) -> Result<ColumnarValue> {
mod tests {
use super::*;
use arrow::array::UInt8Array;
use arrow::datatypes::Int64Type;
use datafusion_common::cast::{
as_generic_string_array, as_list_array, as_uint64_array, as_uint8_array,
};
Expand Down Expand Up @@ -1161,21 +1100,15 @@ mod tests {
#[test]
fn test_array_append() {
// array_append([1, 2, 3], 4) = [1, 2, 3, 4]
let args = [
ColumnarValue::Scalar(ScalarValue::List(
Some(vec![
ScalarValue::Int64(Some(1)),
ScalarValue::Int64(Some(2)),
ScalarValue::Int64(Some(3)),
]),
Arc::new(Field::new("item", DataType::Int64, false)),
)),
ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
];
let data = vec![Some(vec![Some(1), Some(2), Some(3)])];
let list_array =
Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(data)) as ArrayRef;
let int64_array = Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef;

let array = array_append(&args)
.expect("failed to initialize function array_append")
.into_array(1);
let args = [list_array, int64_array];

let array =
array_append(&args).expect("failed to initialize function array_append");
let result =
as_list_array(&array).expect("failed to initialize function array_append");

Expand All @@ -1193,21 +1126,15 @@ mod tests {
#[test]
fn test_array_prepend() {
// array_prepend(1, [2, 3, 4]) = [1, 2, 3, 4]
let args = [
ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
ColumnarValue::Scalar(ScalarValue::List(
Some(vec![
ScalarValue::Int64(Some(2)),
ScalarValue::Int64(Some(3)),
ScalarValue::Int64(Some(4)),
]),
Arc::new(Field::new("item", DataType::Int64, false)),
)),
];
let data = vec![Some(vec![Some(2), Some(3), Some(4)])];
let list_array =
Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(data)) as ArrayRef;
let int64_array = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;

let array = array_prepend(&args)
.expect("failed to initialize function array_append")
.into_array(1);
let args = [int64_array, list_array];

let array =
array_prepend(&args).expect("failed to initialize function array_append");
let result =
as_list_array(&array).expect("failed to initialize function array_append");

Expand All @@ -1225,36 +1152,20 @@ mod tests {
#[test]
fn test_array_concat() {
// array_concat([1, 2, 3], [4, 5, 6], [7, 8, 9]) = [1, 2, 3, 4, 5, 6, 7, 8, 9]
let args = [
ColumnarValue::Scalar(ScalarValue::List(
Some(vec![
ScalarValue::Int64(Some(1)),
ScalarValue::Int64(Some(2)),
ScalarValue::Int64(Some(3)),
]),
Arc::new(Field::new("item", DataType::Int64, false)),
)),
ColumnarValue::Scalar(ScalarValue::List(
Some(vec![
ScalarValue::Int64(Some(4)),
ScalarValue::Int64(Some(5)),
ScalarValue::Int64(Some(6)),
]),
Arc::new(Field::new("item", DataType::Int64, false)),
)),
ColumnarValue::Scalar(ScalarValue::List(
Some(vec![
ScalarValue::Int64(Some(7)),
ScalarValue::Int64(Some(8)),
ScalarValue::Int64(Some(9)),
]),
Arc::new(Field::new("item", DataType::Int64, false)),
)),
];

let array = array_concat(&args)
.expect("failed to initialize function array_concat")
.into_array(1);
let data = vec![Some(vec![Some(1), Some(2), Some(3)])];
let list_array1 =
Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(data)) as ArrayRef;
let data = vec![Some(vec![Some(4), Some(5), Some(6)])];
let list_array2 =
Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(data)) as ArrayRef;
let data = vec![Some(vec![Some(7), Some(8), Some(9)])];
let list_array3 =
Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(data)) as ArrayRef;

let args = [list_array1, list_array2, list_array3];

let array =
array_concat(&args).expect("failed to initialize function array_concat");
let result =
as_list_array(&array).expect("failed to initialize function array_concat");

Expand Down
10 changes: 7 additions & 3 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ use self::kernels_arrow::{
};

use super::column::Column;
use crate::array_expressions::{array_append, array_concat, array_prepend};
use crate::expressions::cast_column;
use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
use crate::intervals::{apply_operator, Interval};
Expand Down Expand Up @@ -1252,9 +1253,12 @@ impl BinaryExpr {
BitwiseXor => bitwise_xor_dyn(left, right),
BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
StringConcat => {
binary_string_array_op!(left, right, concat_elements)
}
StringConcat => match (left_data_type, right_data_type) {
(DataType::List(_), DataType::List(_)) => array_concat(&[left, right]),
(DataType::List(_), _) => array_append(&[left, right]),
(_, DataType::List(_)) => array_prepend(&[left, right]),
_ => binary_string_array_op!(left, right, concat_elements),
},
}
}
}
Expand Down
Loading