diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index 4088414296f9..25e2e4b453a3 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -43,6 +43,20 @@ CREATE TABLE values( (8, 15, 16, 8.8, NULL, '') ; +statement ok +CREATE TABLE values_without_nulls +AS VALUES + (1, 1, 2, 1.1, 'Lorem', 'A'), + (2, 3, 4, 2.2, 'ipsum', ''), + (3, 5, 6, 3.3, 'dolor', 'BB'), + (4, 7, 8, 4.4, 'sit', NULL), + (5, 9, 10, 5.5, 'amet', 'CCC'), + (6, 11, 12, 6.6, ',', 'DD'), + (7, 13, 14, 7.7, 'consectetur', 'E'), + (8, 15, 16, 8.8, 'adipiscing', 'F'), + (9, 17, 18, 9.9, 'elit', '') +; + statement ok CREATE TABLE arrays AS VALUES @@ -996,25 +1010,71 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma [[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] [[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] -## array_fill +## array_repeat (aliases: `list_repeat`) -# array_fill scalar function #1 +# array_repeat scalar function #1 query ??? -select array_fill(11, make_array(1, 2, 3)), array_fill(3, make_array(2, 3)), array_fill(2, make_array(2)); +select array_repeat(1, 5), array_repeat(3.14, 3), array_repeat('l', 4); ---- -[[[11, 11, 11], [11, 11, 11]]] [[3, 3, 3], [3, 3, 3]] [2, 2] +[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] -# array_fill scalar function #2 -query ?? -select array_fill(1, make_array(1, 1, 1)), array_fill(2, make_array(2, 2, 2, 2, 2)); +# array_repeat scalar function #2 (element as list) +query ??? +select array_repeat([1], 5), array_repeat([1.1, 2.2, 3.3], 3), array_repeat([[1, 2], [3, 4]], 2); +---- +[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] + +# list_repeat scalar function #3 (function alias: `array_repeat`) +query ??? +select list_repeat(1, 5), list_repeat(3.14, 3), list_repeat('l', 4); ---- -[[[1]]] [[[[[2, 2], [2, 2]], [[2, 2], [2, 2]]], [[[2, 2], [2, 2]], [[2, 2], [2, 2]]]], [[[[2, 2], [2, 2]], [[2, 2], [2, 2]]], [[[2, 2], [2, 2]], [[2, 2], [2, 2]]]]] +[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] -# array_fill scalar function #3 +# array_repeat with columns #1 +query ? +select array_repeat(column4, column1) from values_without_nulls; +---- +[1.1] +[2.2, 2.2] +[3.3, 3.3, 3.3] +[4.4, 4.4, 4.4, 4.4] +[5.5, 5.5, 5.5, 5.5, 5.5] +[6.6, 6.6, 6.6, 6.6, 6.6, 6.6] +[7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7] +[8.8, 8.8, 8.8, 8.8, 8.8, 8.8, 8.8, 8.8] +[9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9] + +# array_repeat with columns #2 (element as list) query ? -select array_fill(1, make_array()) +select array_repeat(column1, column3) from arrays_values_without_nulls; ---- -[] +[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] +[[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]] +[[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]] +[[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40]] + +# array_repeat with columns and scalars #1 +query ?? +select array_repeat(1, column1), array_repeat(column4, 3) from values_without_nulls; +---- +[1] [1.1, 1.1, 1.1] +[1, 1] [2.2, 2.2, 2.2] +[1, 1, 1] [3.3, 3.3, 3.3] +[1, 1, 1, 1] [4.4, 4.4, 4.4] +[1, 1, 1, 1, 1] [5.5, 5.5, 5.5] +[1, 1, 1, 1, 1, 1] [6.6, 6.6, 6.6] +[1, 1, 1, 1, 1, 1, 1] [7.7, 7.7, 7.7] +[1, 1, 1, 1, 1, 1, 1, 1] [8.8, 8.8, 8.8] +[1, 1, 1, 1, 1, 1, 1, 1, 1] [9.9, 9.9, 9.9] + +# array_repeat with columns and scalars #2 (element as list) +query ?? +select array_repeat([1], column3), array_repeat(column1, 3) from arrays_values_without_nulls; +---- +[[1]] [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] +[[1], [1]] [[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]] +[[1], [1], [1]] [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]] +[[1], [1], [1], [1]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40]] ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) @@ -1570,7 +1630,7 @@ h,e,l,l,o 1-2-3-4-5 1|2|3 # array_to_string scalar function #2 query TTT -select array_to_string([1, 1, 1], '1'), array_to_string([[1, 2], [3, 4], [5, 6]], '+'), array_to_string(array_fill(3, [3, 2, 2]), '/\'); +select array_to_string([1, 1, 1], '1'), array_to_string([[1, 2], [3, 4], [5, 6]], '+'), array_to_string(array_repeat(array_repeat(array_repeat(3, 2), 2), 3), '/\'); ---- 11111 1+2+3+4+5+6 3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3 @@ -1670,7 +1730,7 @@ select cardinality(make_array(1, 2, 3, 4, 5)), cardinality([1, 3, 5]), cardinali # cardinality scalar function #2 query II -select cardinality(make_array([1, 2], [3, 4], [5, 6])), cardinality(array_fill(3, array[3, 2, 3])); +select cardinality(make_array([1, 2], [3, 4], [5, 6])), cardinality(array_repeat(array_repeat(array_repeat(3, 3), 2), 3)); ---- 6 18 @@ -1883,10 +1943,10 @@ select array_length(make_array(1, 2, 3, 4, 5), 2), array_length(make_array(1, 2, NULL NULL 2 # array_length scalar function #4 -query IIII -select array_length(array_fill(3, [3, 2, 5]), 1), array_length(array_fill(3, [3, 2, 5]), 2), array_length(array_fill(3, [3, 2, 5]), 3), array_length(array_fill(3, [3, 2, 5]), 4); +query II +select array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 1), array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 2); ---- -3 2 5 NULL +3 2 # array_length scalar function #5 query III @@ -1936,7 +1996,7 @@ select array_dims(make_array(1, 2, 3)), array_dims(make_array([1, 2], [3, 4])), # array_dims scalar function #2 query ?? -select array_dims(array_fill(2, [1, 2, 3])), array_dims(array_fill(3, [2, 5, 4])); +select array_dims(array_repeat(array_repeat(array_repeat(2, 3), 2), 1)), array_dims(array_repeat(array_repeat(array_repeat(3, 4), 5), 2)); ---- [1, 2, 3] [2, 5, 4] @@ -1974,7 +2034,7 @@ select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])) # array_ndims scalar function #2 query II -select array_ndims(array_fill(1, [1, 2, 3])), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); +select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); ---- 3 21 @@ -2264,6 +2324,9 @@ select array_concat(column1, [7]) from arrays_values_v2; statement ok drop table values; +statement ok +drop table values_without_nulls; + statement ok drop table nested_arrays; diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 2886617c0d45..061d0689cd97 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -131,8 +131,6 @@ pub enum BuiltinScalarFunction { ArrayDims, /// array_element ArrayElement, - /// array_fill - ArrayFill, /// array_length ArrayLength, /// array_ndims @@ -149,6 +147,8 @@ pub enum BuiltinScalarFunction { ArrayRemoveN, /// array_remove_all ArrayRemoveAll, + /// array_repeat + ArrayRepeat, /// array_replace ArrayReplace, /// array_replace_n @@ -354,12 +354,12 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayHas => Volatility::Immutable, BuiltinScalarFunction::ArrayDims => Volatility::Immutable, BuiltinScalarFunction::ArrayElement => Volatility::Immutable, - BuiltinScalarFunction::ArrayFill => Volatility::Immutable, BuiltinScalarFunction::ArrayLength => Volatility::Immutable, BuiltinScalarFunction::ArrayNdims => Volatility::Immutable, BuiltinScalarFunction::ArrayPosition => Volatility::Immutable, BuiltinScalarFunction::ArrayPositions => Volatility::Immutable, BuiltinScalarFunction::ArrayPrepend => Volatility::Immutable, + BuiltinScalarFunction::ArrayRepeat => Volatility::Immutable, BuiltinScalarFunction::ArrayRemove => Volatility::Immutable, BuiltinScalarFunction::ArrayRemoveN => Volatility::Immutable, BuiltinScalarFunction::ArrayRemoveAll => Volatility::Immutable, @@ -536,11 +536,6 @@ impl BuiltinScalarFunction { "The {self} function can only accept list as the first argument" ))), }, - BuiltinScalarFunction::ArrayFill => Ok(List(Arc::new(Field::new( - "item", - input_expr_types[1].clone(), - true, - )))), BuiltinScalarFunction::ArrayLength => Ok(UInt64), BuiltinScalarFunction::ArrayNdims => Ok(UInt64), BuiltinScalarFunction::ArrayPosition => Ok(UInt64), @@ -548,6 +543,11 @@ impl BuiltinScalarFunction { Ok(List(Arc::new(Field::new("item", UInt64, true)))) } BuiltinScalarFunction::ArrayPrepend => Ok(input_expr_types[1].clone()), + BuiltinScalarFunction::ArrayRepeat => Ok(List(Arc::new(Field::new( + "item", + input_expr_types[0].clone(), + true, + )))), BuiltinScalarFunction::ArrayRemove => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayRemoveN => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayRemoveAll => Ok(input_expr_types[0].clone()), @@ -822,7 +822,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::ArrayHas => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayFill => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayLength => { Signature::variadic_any(self.volatility()) } @@ -832,6 +831,7 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayPrepend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()), BuiltinScalarFunction::ArrayRemoveAll => Signature::any(2, self.volatility()), @@ -1310,7 +1310,6 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::ArrayHas => { &["array_has", "list_has", "array_contains", "list_contains"] } - BuiltinScalarFunction::ArrayFill => &["array_fill"], BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"], BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"], BuiltinScalarFunction::ArrayPosition => &[ @@ -1326,6 +1325,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { "array_push_front", "list_push_front", ], + BuiltinScalarFunction::ArrayRepeat => &["array_repeat", "list_repeat"], BuiltinScalarFunction::ArrayRemove => &["array_remove", "list_remove"], BuiltinScalarFunction::ArrayRemoveN => &["array_remove_n", "list_remove_n"], BuiltinScalarFunction::ArrayRemoveAll => &["array_remove_all", "list_remove_all"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 6d0b6c1d6535..ef6ce8171153 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -576,12 +576,6 @@ scalar_expr!( array element, "extracts the element with the index n from the array." ); -scalar_expr!( - ArrayFill, - array_fill, - element array, - "returns an array filled with copies of the given value." -); scalar_expr!( ArrayLength, array_length, @@ -612,6 +606,12 @@ scalar_expr!( array element, "prepends an element to the beginning of an array." ); +scalar_expr!( + ArrayRepeat, + array_repeat, + element count, + "returns an array containing element `count` times." +); scalar_expr!( ArrayRemove, array_remove, @@ -1062,12 +1062,12 @@ mod test { test_scalar_expr!(ArrayAppend, array_append, array, element); test_unary_scalar_expr!(ArrayDims, array_dims); - test_scalar_expr!(ArrayFill, array_fill, element, array); test_scalar_expr!(ArrayLength, array_length, array, dimension); test_unary_scalar_expr!(ArrayNdims, array_ndims); test_scalar_expr!(ArrayPosition, array_position, array, element, index); test_scalar_expr!(ArrayPositions, array_positions, array, element); test_scalar_expr!(ArrayPrepend, array_prepend, array, element); + test_scalar_expr!(ArrayRepeat, array_repeat, element, count); test_scalar_expr!(ArrayRemove, array_remove, array, element); test_scalar_expr!(ArrayRemoveN, array_remove_n, array, element, max); test_scalar_expr!(ArrayRemoveAll, array_remove_all, array, element); diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 74eca0b8cd74..a223a6998a39 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -826,84 +826,169 @@ pub fn array_concat(args: &[ArrayRef]) -> Result { concat_internal(new_args.as_slice()) } -macro_rules! fill { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); +macro_rules! general_repeat { + ($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{ + let mut offsets: Vec = vec![0]; + let mut values = + downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); - let mut acc = ColumnarValue::Scalar($ELEMENT); - for value in arr.iter().rev() { - match value { - Some(value) => { - let mut repeated = vec![]; - for _ in 0..value { - repeated.push(acc.clone()); - } - acc = array(repeated.as_slice()).unwrap(); + let element_array = downcast_arg!($ELEMENT, $ARRAY_TYPE); + for (el, c) in element_array.iter().zip($COUNT.iter()) { + let last_offset: i32 = offsets.last().copied().ok_or_else(|| { + DataFusionError::Internal(format!("offsets should not be empty")) + })?; + match el { + Some(el) => { + let c = if c < Some(0) { 0 } else { c.unwrap() } as usize; + let repeated_array = + [Some(el.clone())].repeat(c).iter().collect::<$ARRAY_TYPE>(); + + values = downcast_arg!( + compute::concat(&[&values, &repeated_array])?.clone(), + $ARRAY_TYPE + ) + .clone(); + offsets.push(last_offset + repeated_array.len() as i32); } - _ => { - return Err(DataFusionError::Internal(format!( - "Array_fill function requires non nullable array" - ))); + None => { + offsets.push(last_offset); } } } - acc + let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); + + Arc::new(ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + )?) }}; } -/// Array_fill SQL function -pub fn array_fill(args: &[ColumnarValue]) -> Result { - if args.len() != 2 { - return Err(DataFusionError::Internal(format!( - "Array_fill function requires two arguments, got {}", - args.len() - ))); - } +macro_rules! general_repeat_list { + ($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{ + let mut offsets: Vec = vec![0]; + let mut values = + downcast_arg!(new_empty_array($ELEMENT.data_type()), ListArray).clone(); - let element = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.clone(), - _ => { - return Err(DataFusionError::Internal( - "Array_fill function requires scalar element".to_string(), - )) - } - }; + let element_array = downcast_arg!($ELEMENT, ListArray); + for (el, c) in element_array.iter().zip($COUNT.iter()) { + let last_offset: i32 = offsets.last().copied().ok_or_else(|| { + DataFusionError::Internal(format!("offsets should not be empty")) + })?; + match el { + Some(el) => { + let c = if c < Some(0) { 0 } else { c.unwrap() } as usize; + let repeated_vec = vec![el; c]; - let arr = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), - }; + let mut i: i32 = 0; + let mut repeated_offsets = vec![i]; + repeated_offsets.extend( + repeated_vec + .clone() + .into_iter() + .map(|a| { + i += a.len() as i32; + i + }) + .collect::>(), + ); - let res = match arr.data_type() { - DataType::List(..) => { - let arr = downcast_arg!(arr, ListArray); - let array_values = arr.values(); - match arr.value_type() { - DataType::Int8 => fill!(array_values, element, Int8Array), - DataType::Int16 => fill!(array_values, element, Int16Array), - DataType::Int32 => fill!(array_values, element, Int32Array), - DataType::Int64 => fill!(array_values, element, Int64Array), - DataType::UInt8 => fill!(array_values, element, UInt8Array), - DataType::UInt16 => fill!(array_values, element, UInt16Array), - DataType::UInt32 => fill!(array_values, element, UInt32Array), - DataType::UInt64 => fill!(array_values, element, UInt64Array), - DataType::Null => { - return Ok(datafusion_expr::ColumnarValue::Scalar( - ScalarValue::new_list(Some(vec![]), DataType::Null), - )) + let mut repeated_values = downcast_arg!( + new_empty_array(&element_array.value_type()), + $ARRAY_TYPE + ) + .clone(); + for repeated_list in repeated_vec { + repeated_values = downcast_arg!( + compute::concat(&[&repeated_values, &repeated_list])?, + $ARRAY_TYPE + ) + .clone(); + } + + let field = Arc::new(Field::new( + "item", + element_array.value_type().clone(), + true, + )); + let repeated_array = ListArray::try_new( + field, + OffsetBuffer::new(repeated_offsets.clone().into()), + Arc::new(repeated_values), + None, + )?; + + values = downcast_arg!( + compute::concat(&[&values, &repeated_array,])?.clone(), + ListArray + ) + .clone(); + offsets.push(last_offset + repeated_array.len() as i32); } - data_type => { - return Err(DataFusionError::Internal(format!( - "Array_fill is not implemented for type '{data_type:?}'." - ))); + None => { + offsets.push(last_offset); } } } + + let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); + + Arc::new(ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + )?) + }}; +} + +/// Array_repeat SQL function +pub fn array_repeat(args: &[ArrayRef]) -> Result { + let element = &args[0]; + let count = as_int64_array(&args[1])?; + + let res = match element.data_type() { + DataType::List(field) => match field.data_type() { + DataType::List(_) => general_repeat_list!(element, count, ListArray), + DataType::Utf8 => general_repeat_list!(element, count, StringArray), + DataType::LargeUtf8 => general_repeat_list!(element, count, LargeStringArray), + DataType::Boolean => general_repeat_list!(element, count, BooleanArray), + DataType::Float32 => general_repeat_list!(element, count, Float32Array), + DataType::Float64 => general_repeat_list!(element, count, Float64Array), + DataType::Int8 => general_repeat_list!(element, count, Int8Array), + DataType::Int16 => general_repeat_list!(element, count, Int16Array), + DataType::Int32 => general_repeat_list!(element, count, Int32Array), + DataType::Int64 => general_repeat_list!(element, count, Int64Array), + DataType::UInt8 => general_repeat_list!(element, count, UInt8Array), + DataType::UInt16 => general_repeat_list!(element, count, UInt16Array), + DataType::UInt32 => general_repeat_list!(element, count, UInt32Array), + DataType::UInt64 => general_repeat_list!(element, count, UInt64Array), + data_type => { + return Err(DataFusionError::NotImplemented(format!( + "Array_repeat is not implemented for types 'List({data_type:?})'." + ))) + } + }, + DataType::Utf8 => general_repeat!(element, count, StringArray), + DataType::LargeUtf8 => general_repeat!(element, count, LargeStringArray), + DataType::Boolean => general_repeat!(element, count, BooleanArray), + DataType::Float32 => general_repeat!(element, count, Float32Array), + DataType::Float64 => general_repeat!(element, count, Float64Array), + DataType::Int8 => general_repeat!(element, count, Int8Array), + DataType::Int16 => general_repeat!(element, count, Int16Array), + DataType::Int32 => general_repeat!(element, count, Int32Array), + DataType::Int64 => general_repeat!(element, count, Int64Array), + DataType::UInt8 => general_repeat!(element, count, UInt8Array), + DataType::UInt16 => general_repeat!(element, count, UInt16Array), + DataType::UInt32 => general_repeat!(element, count, UInt32Array), + DataType::UInt64 => general_repeat!(element, count, UInt64Array), data_type => { - return Err(DataFusionError::Internal(format!( - "Array is not type '{data_type:?}'." - ))); + return Err(DataFusionError::NotImplemented(format!( + "Array_repeat is not implemented for types '{data_type:?}'." + ))) } }; @@ -964,31 +1049,24 @@ pub fn array_position(args: &[ArrayRef]) -> Result { Int64Array::from_value(0, arr.len()) }; - let res = match arr.data_type() { - DataType::List(field) => match field.data_type() { - DataType::List(_) => position!(arr, element, index, ListArray), - DataType::Utf8 => position!(arr, element, index, StringArray), - DataType::LargeUtf8 => position!(arr, element, index, LargeStringArray), - DataType::Boolean => position!(arr, element, index, BooleanArray), - DataType::Float32 => position!(arr, element, index, Float32Array), - DataType::Float64 => position!(arr, element, index, Float64Array), - DataType::Int8 => position!(arr, element, index, Int8Array), - DataType::Int16 => position!(arr, element, index, Int16Array), - DataType::Int32 => position!(arr, element, index, Int32Array), - DataType::Int64 => position!(arr, element, index, Int64Array), - DataType::UInt8 => position!(arr, element, index, UInt8Array), - DataType::UInt16 => position!(arr, element, index, UInt16Array), - DataType::UInt32 => position!(arr, element, index, UInt32Array), - DataType::UInt64 => position!(arr, element, index, UInt64Array), - data_type => { - return Err(DataFusionError::NotImplemented(format!( - "Array_position is not implemented for types '{data_type:?}'." - ))) - } - }, + let res = match arr.value_type() { + DataType::List(_) => position!(arr, element, index, ListArray), + DataType::Utf8 => position!(arr, element, index, StringArray), + DataType::LargeUtf8 => position!(arr, element, index, LargeStringArray), + DataType::Boolean => position!(arr, element, index, BooleanArray), + DataType::Float32 => position!(arr, element, index, Float32Array), + DataType::Float64 => position!(arr, element, index, Float64Array), + DataType::Int8 => position!(arr, element, index, Int8Array), + DataType::Int16 => position!(arr, element, index, Int16Array), + DataType::Int32 => position!(arr, element, index, Int32Array), + DataType::Int64 => position!(arr, element, index, Int64Array), + DataType::UInt8 => position!(arr, element, index, UInt8Array), + DataType::UInt16 => position!(arr, element, index, UInt16Array), + DataType::UInt32 => position!(arr, element, index, UInt32Array), + DataType::UInt64 => position!(arr, element, index, UInt64Array), data_type => { return Err(DataFusionError::NotImplemented(format!( - "Array is not type '{data_type:?}'." + "Array_position is not implemented for types '{data_type:?}'." ))) } }; @@ -1050,31 +1128,24 @@ pub fn array_positions(args: &[ArrayRef]) -> Result { let arr = as_list_array(&args[0])?; let element = &args[1]; - let res = match arr.data_type() { - DataType::List(field) => match field.data_type() { - DataType::List(_) => positions!(arr, element, ListArray), - DataType::Utf8 => positions!(arr, element, StringArray), - DataType::LargeUtf8 => positions!(arr, element, LargeStringArray), - DataType::Boolean => positions!(arr, element, BooleanArray), - DataType::Float32 => positions!(arr, element, Float32Array), - DataType::Float64 => positions!(arr, element, Float64Array), - DataType::Int8 => positions!(arr, element, Int8Array), - DataType::Int16 => positions!(arr, element, Int16Array), - DataType::Int32 => positions!(arr, element, Int32Array), - DataType::Int64 => positions!(arr, element, Int64Array), - DataType::UInt8 => positions!(arr, element, UInt8Array), - DataType::UInt16 => positions!(arr, element, UInt16Array), - DataType::UInt32 => positions!(arr, element, UInt32Array), - DataType::UInt64 => positions!(arr, element, UInt64Array), - data_type => { - return Err(DataFusionError::NotImplemented(format!( - "Array_positions is not implemented for types '{data_type:?}'." - ))) - } - }, + let res = match arr.value_type() { + DataType::List(_) => positions!(arr, element, ListArray), + DataType::Utf8 => positions!(arr, element, StringArray), + DataType::LargeUtf8 => positions!(arr, element, LargeStringArray), + DataType::Boolean => positions!(arr, element, BooleanArray), + DataType::Float32 => positions!(arr, element, Float32Array), + DataType::Float64 => positions!(arr, element, Float64Array), + DataType::Int8 => positions!(arr, element, Int8Array), + DataType::Int16 => positions!(arr, element, Int16Array), + DataType::Int32 => positions!(arr, element, Int32Array), + DataType::Int64 => positions!(arr, element, Int64Array), + DataType::UInt8 => positions!(arr, element, UInt8Array), + DataType::UInt16 => positions!(arr, element, UInt16Array), + DataType::UInt32 => positions!(arr, element, UInt32Array), + DataType::UInt64 => positions!(arr, element, UInt64Array), data_type => { return Err(DataFusionError::NotImplemented(format!( - "Array is not type '{data_type:?}'." + "Array_positions is not implemented for types '{data_type:?}'." ))) } }; @@ -2585,35 +2656,6 @@ mod tests { ); } - #[test] - fn test_array_fill() { - // array_fill(4, [5]) = [4, 4, 4, 4, 4] - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ColumnarValue::Scalar(ScalarValue::List( - Some(vec![ScalarValue::Int64(Some(5))]), - Arc::new(Field::new("item", DataType::Int64, false)), - )), - ]; - - let array = array_fill(&args) - .expect("failed to initialize function array_fill") - .into_array(1); - let result = - as_list_array(&array).expect("failed to initialize function array_fill"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[4, 4, 4, 4, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - #[test] fn test_array_position() { // array_position([1, 2, 3, 4], 3) = 3 @@ -3002,6 +3044,55 @@ mod tests { ); } + #[test] + fn test_array_repeat() { + // array_repeat(3, 5) = [3, 3, 3, 3, 3] + let array = array_repeat(&[ + Arc::new(Int64Array::from_value(3, 1)), + Arc::new(Int64Array::from_value(5, 1)), + ]) + .expect("failed to initialize function array_repeat"); + let result = + as_list_array(&array).expect("failed to initialize function array_repeat"); + + assert_eq!(result.len(), 1); + assert_eq!( + &[3, 3, 3, 3, 3], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + } + + #[test] + fn test_nested_array_repeat() { + // array_repeat([1, 2, 3, 4], 3) = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]] + let element = return_array().into_array(1); + let array = array_repeat(&[element, Arc::new(Int64Array::from_value(3, 1))]) + .expect("failed to initialize function array_repeat"); + let result = + as_list_array(&array).expect("failed to initialize function array_repeat"); + + assert_eq!(result.len(), 1); + let data = vec![ + Some(vec![Some(1), Some(2), Some(3), Some(4)]), + Some(vec![Some(1), Some(2), Some(3), Some(4)]), + Some(vec![Some(1), Some(2), Some(3), Some(4)]), + ]; + let expected = ListArray::from_iter_primitive::(data); + assert_eq!( + expected, + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .clone() + ); + } #[test] fn test_array_to_string() { // array_to_string([1, 2, 3, 4], ',') = 1,2,3,4 diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index e877faa3c12e..df76d55bfcaa 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -434,7 +434,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayElement => { Arc::new(|args| make_scalar_function(array_expressions::array_element)(args)) } - BuiltinScalarFunction::ArrayFill => Arc::new(array_expressions::array_fill), BuiltinScalarFunction::ArrayLength => { Arc::new(|args| make_scalar_function(array_expressions::array_length)(args)) } @@ -450,6 +449,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayPrepend => { Arc::new(|args| make_scalar_function(array_expressions::array_prepend)(args)) } + BuiltinScalarFunction::ArrayRepeat => { + Arc::new(|args| make_scalar_function(array_expressions::array_repeat)(args)) + } BuiltinScalarFunction::ArrayRemove => { Arc::new(|args| make_scalar_function(array_expressions::array_remove)(args)) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 3ef5850b546b..d118b3c1c9f6 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -557,7 +557,7 @@ enum ScalarFunction { ArrayAppend = 86; ArrayConcat = 87; ArrayDims = 88; - ArrayFill = 89; + ArrayRepeat = 89; ArrayLength = 90; ArrayNdims = 91; ArrayPosition = 92; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index df70a70f6c4f..9768dc6fdb73 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -18281,7 +18281,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayAppend => "ArrayAppend", Self::ArrayConcat => "ArrayConcat", Self::ArrayDims => "ArrayDims", - Self::ArrayFill => "ArrayFill", + Self::ArrayRepeat => "ArrayRepeat", Self::ArrayLength => "ArrayLength", Self::ArrayNdims => "ArrayNdims", Self::ArrayPosition => "ArrayPosition", @@ -18404,7 +18404,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayAppend", "ArrayConcat", "ArrayDims", - "ArrayFill", + "ArrayRepeat", "ArrayLength", "ArrayNdims", "ArrayPosition", @@ -18558,7 +18558,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayAppend" => Ok(ScalarFunction::ArrayAppend), "ArrayConcat" => Ok(ScalarFunction::ArrayConcat), "ArrayDims" => Ok(ScalarFunction::ArrayDims), - "ArrayFill" => Ok(ScalarFunction::ArrayFill), + "ArrayRepeat" => Ok(ScalarFunction::ArrayRepeat), "ArrayLength" => Ok(ScalarFunction::ArrayLength), "ArrayNdims" => Ok(ScalarFunction::ArrayNdims), "ArrayPosition" => Ok(ScalarFunction::ArrayPosition), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index a5f4d27e5fad..d57455a29f47 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2291,7 +2291,7 @@ pub enum ScalarFunction { ArrayAppend = 86, ArrayConcat = 87, ArrayDims = 88, - ArrayFill = 89, + ArrayRepeat = 89, ArrayLength = 90, ArrayNdims = 91, ArrayPosition = 92, @@ -2411,7 +2411,7 @@ impl ScalarFunction { ScalarFunction::ArrayAppend => "ArrayAppend", ScalarFunction::ArrayConcat => "ArrayConcat", ScalarFunction::ArrayDims => "ArrayDims", - ScalarFunction::ArrayFill => "ArrayFill", + ScalarFunction::ArrayRepeat => "ArrayRepeat", ScalarFunction::ArrayLength => "ArrayLength", ScalarFunction::ArrayNdims => "ArrayNdims", ScalarFunction::ArrayPosition => "ArrayPosition", @@ -2528,7 +2528,7 @@ impl ScalarFunction { "ArrayAppend" => Some(Self::ArrayAppend), "ArrayConcat" => Some(Self::ArrayConcat), "ArrayDims" => Some(Self::ArrayDims), - "ArrayFill" => Some(Self::ArrayFill), + "ArrayRepeat" => Some(Self::ArrayRepeat), "ArrayLength" => Some(Self::ArrayLength), "ArrayNdims" => Some(Self::ArrayNdims), "ArrayPosition" => Some(Self::ArrayPosition), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index add4e78548d8..9d7dd4e49029 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -37,9 +37,9 @@ use datafusion_common::{ use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::{ abs, acos, acosh, array, array_append, array_concat, array_dims, array_element, - array_fill, array_has, array_has_all, array_has_any, array_length, array_ndims, - array_position, array_positions, array_prepend, array_remove, array_remove_all, - array_remove_n, array_replace, array_replace_all, array_replace_n, array_slice, + array_has, array_has_all, array_has_any, array_length, array_ndims, array_position, + array_positions, array_prepend, array_remove, array_remove_all, array_remove_n, + array_repeat, array_replace, array_replace_all, array_replace_n, array_slice, array_to_string, ascii, asin, asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, @@ -457,12 +457,12 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayHas => Self::ArrayHas, ScalarFunction::ArrayDims => Self::ArrayDims, ScalarFunction::ArrayElement => Self::ArrayElement, - ScalarFunction::ArrayFill => Self::ArrayFill, ScalarFunction::ArrayLength => Self::ArrayLength, ScalarFunction::ArrayNdims => Self::ArrayNdims, ScalarFunction::ArrayPosition => Self::ArrayPosition, ScalarFunction::ArrayPositions => Self::ArrayPositions, ScalarFunction::ArrayPrepend => Self::ArrayPrepend, + ScalarFunction::ArrayRepeat => Self::ArrayRepeat, ScalarFunction::ArrayRemove => Self::ArrayRemove, ScalarFunction::ArrayRemoveN => Self::ArrayRemoveN, ScalarFunction::ArrayRemoveAll => Self::ArrayRemoveAll, @@ -1245,10 +1245,6 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), - ScalarFunction::ArrayFill => Ok(array_fill( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - )), ScalarFunction::ArrayPosition => Ok(array_position( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1258,6 +1254,10 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArrayRepeat => Ok(array_repeat( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ArrayRemove => Ok(array_remove( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 78588ba536b0..4df96c0c5417 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1426,12 +1426,12 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayHas => Self::ArrayHas, BuiltinScalarFunction::ArrayDims => Self::ArrayDims, BuiltinScalarFunction::ArrayElement => Self::ArrayElement, - BuiltinScalarFunction::ArrayFill => Self::ArrayFill, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, BuiltinScalarFunction::ArrayNdims => Self::ArrayNdims, BuiltinScalarFunction::ArrayPosition => Self::ArrayPosition, BuiltinScalarFunction::ArrayPositions => Self::ArrayPositions, BuiltinScalarFunction::ArrayPrepend => Self::ArrayPrepend, + BuiltinScalarFunction::ArrayRepeat => Self::ArrayRepeat, BuiltinScalarFunction::ArrayRemove => Self::ArrayRemove, BuiltinScalarFunction::ArrayRemoveN => Self::ArrayRemoveN, BuiltinScalarFunction::ArrayRemoveAll => Self::ArrayRemoveAll, diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 0278f779388d..a04f43fd4b2b 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -188,12 +188,12 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` | | array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | | array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` | -| array_fill(element, array) | Returns an array filled with copies of the given value. | | array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | | array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | | array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` | | array_positions(array, element) | Searches for an element in the array, returns all occurrences. `array_positions([1, 2, 2, 3, 4], 2) -> [2, 3]` | | array_prepend(array, element) | Prepends an element to the beginning of an array. `array_prepend(1, [2, 3, 4]) -> [1, 2, 3, 4]` | +| array_repeat(element, count) | Returns an array containing element `count` times. `array_repeat(1, 3) -> [1, 1, 1]` | | array_remove(array, element) | Removes the first element from the array equal to the given value. `array_remove([1, 2, 2, 3, 2, 1, 4], 2) -> [1, 2, 3, 2, 1, 4]` | | array_remove_n(array, element, max) | Removes the first `max` elements from the array equal to the given value. `array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2) -> [1, 3, 2, 1, 4]` | | array_remove_all(array, element) | Removes all elements from the array equal to the given value. `array_remove_all([1, 2, 2, 3, 2, 1, 4], 2) -> [1, 3, 1, 4]` | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index fa1bac5d9426..dec120db18c5 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1448,7 +1448,6 @@ from_unixtime(expression) - [array_dims](#array_dims) - [array_element](#array_element) - [array_extract](#array_extract) -- [array_fill](#array_fill) - [array_indexof](#array_indexof) - [array_join](#array_join) - [array_length](#array_length) @@ -1458,6 +1457,7 @@ from_unixtime(expression) - [array_positions](#array_positions) - [array_push_back](#array_push_back) - [array_push_front](#array_push_front) +- [array_repeat](#array_repeat) - [array_remove](#array_remove) - [array_remove_n](#array_remove_n) - [array_remove_all](#array_remove_all) @@ -1482,6 +1482,7 @@ from_unixtime(expression) - [list_positions](#list_positions) - [list_push_back](#list_push_back) - [list_push_front](#list_push_front) +- [list_repeat](#list_repeat) - [list_remove](#list_remove) - [list_remove_n](#list_remove_n) - [list_remove_all](#list_remove_all) @@ -1672,6 +1673,8 @@ _Alias of [array_element](#array_element)._ Returns an array filled with copies of the given value. +DEPRECATED: use `array_repeat` instead! + ``` array_fill(element, array) ``` @@ -1848,6 +1851,40 @@ _Alias of [array_append](#array_append)._ _Alias of [array_prepend](#array_prepend)._ +### `array_repeat` + +Returns an array containing element `count` times. + +``` +array_repeat(element, count) +``` + +#### Arguments + +- **element**: Element expression. + Can be a constant, column, or function, and any combination of array operators. +- **count**: Value of how many times to repeat the element. + +#### Example + +``` +❯ select array_repeat(1, 3); ++---------------------------------+ +| array_repeat(Int64(1),Int64(3)) | ++---------------------------------+ +| [1, 1, 1] | ++---------------------------------+ +``` + +``` +❯ select array_repeat([1, 2], 2); ++------------------------------------+ +| array_repeat(List([1,2]),Int64(2)) | ++------------------------------------+ +| [[1, 2], [1, 2]] | ++------------------------------------+ +``` + ### `array_remove` Removes the first element from the array equal to the given value. @@ -2165,6 +2202,10 @@ _Alias of [array_append](#array_append)._ _Alias of [array_prepend](#array_prepend)._ +### `list_repeat` + +_Alias of [array_repeat](#array_repeat)._ + ### `list_remove` _Alias of [array_remove](#array_remove)._