diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index aad021610fcb..e377373e83f1 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -2018,8 +2018,21 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { ) -> Result<&mut String> { match arr.data_type() { DataType::List(..) => { - let list_array = downcast_arg!(arr, ListArray); + let list_array = as_list_array(&arr)?; + for i in 0..list_array.len() { + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; + } + Ok(arg) + } + DataType::LargeList(..) => { + let list_array = as_large_list_array(&arr)?; for i in 0..list_array.len() { compute_array_to_string( arg, @@ -2051,35 +2064,61 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { } } - let mut arg = String::from(""); - let mut res: Vec> = Vec::new(); - - match arr.data_type() { - DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => { - let list_array = arr.as_list::(); - for (arr, &delimiter) in list_array.iter().zip(delimiters.iter()) { - if let (Some(arr), Some(delimiter)) = (arr, delimiter) { - arg = String::from(""); - let s = compute_array_to_string( - &mut arg, - arr, - delimiter.to_string(), - null_string.clone(), - with_null_string, - )? - .clone(); - - if let Some(s) = s.strip_suffix(delimiter) { - res.push(Some(s.to_string())); - } else { - res.push(Some(s)); - } + fn generate_string_array( + list_arr: &GenericListArray, + delimiters: Vec>, + null_string: String, + with_null_string: bool, + ) -> Result { + let mut res: Vec> = Vec::new(); + for (arr, &delimiter) in list_arr.iter().zip(delimiters.iter()) { + if let (Some(arr), Some(delimiter)) = (arr, delimiter) { + let mut arg = String::from(""); + let s = compute_array_to_string( + &mut arg, + arr, + delimiter.to_string(), + null_string.clone(), + with_null_string, + )? + .clone(); + + if let Some(s) = s.strip_suffix(delimiter) { + res.push(Some(s.to_string())); } else { - res.push(None); + res.push(Some(s)); } + } else { + res.push(None); } } + + Ok(StringArray::from(res)) + } + + let arr_type = arr.data_type(); + let string_arr = match arr_type { + DataType::List(_) | DataType::FixedSizeList(_, _) => { + let list_array = as_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } _ => { + let mut arg = String::from(""); + let mut res: Vec> = Vec::new(); // delimiter length is 1 assert_eq!(delimiters.len(), 1); let delimiter = delimiters[0].unwrap(); @@ -2098,10 +2137,11 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { } else { res.push(Some(s)); } + StringArray::from(res) } - } + }; - Ok(Arc::new(StringArray::from(res))) + Ok(Arc::new(string_arr)) } /// Cardinality SQL function diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index a3b2c8cdf1e9..b07e72a689a9 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -3201,30 +3201,55 @@ select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, ---- h,e,l,l,o 1-2-3-4-5 1|2|3 +query TTT +select list_to_string(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), list_to_string(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), list_to_string(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + # array_join scalar function #5 (function alias `array_to_string`) query TTT select array_join(['h', 'e', 'l', 'l', 'o'], ','), array_join([1, 2, 3, 4, 5], '-'), array_join([1.0, 2.0, 3.0], '|'); ---- h,e,l,l,o 1-2-3-4-5 1|2|3 +query TTT +select array_join(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), array_join(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), array_join(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + # list_join scalar function #6 (function alias `list_join`) query TTT select list_join(['h', 'e', 'l', 'l', 'o'], ','), list_join([1, 2, 3, 4, 5], '-'), list_join([1.0, 2.0, 3.0], '|'); ---- h,e,l,l,o 1-2-3-4-5 1|2|3 +query TTT +select list_join(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), list_join(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), list_join(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + # array_to_string scalar function with nulls #1 query TTT select array_to_string(make_array('h', NULL, 'l', NULL, 'o'), ','), array_to_string(make_array(1, NULL, 3, NULL, 5), '-'), array_to_string(make_array(NULL, 2.0, 3.0), '|'); ---- h,l,o 1-3-5 2|3 +query TTT +select array_to_string(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), array_to_string(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), array_to_string(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + # array_to_string scalar function with nulls #2 query TTT select array_to_string(make_array('h', NULL, NULL, NULL, 'o'), ',', '-'), array_to_string(make_array(NULL, 2, NULL, 4, 5), '-', 'nil'), array_to_string(make_array(1.0, NULL, 3.0), '|', '0'); ---- h,-,-,-,o nil-2-nil-4-5 1|0|3 +query TTT +select array_to_string(arrow_cast(make_array('h', NULL, NULL, NULL, 'o'), 'LargeList(Utf8)'), ',', '-'), array_to_string(arrow_cast(make_array(NULL, 2, NULL, 4, 5), 'LargeList(Int64)'), '-', 'nil'), array_to_string(arrow_cast(make_array(1.0, NULL, 3.0), 'LargeList(Float64)'), '|', '0'); +---- +h,-,-,-,o nil-2-nil-4-5 1|0|3 + # array_to_string with columns #1 # For reference @@ -3251,6 +3276,18 @@ NULL 51^52^54^55^56^57^58^59^60 NULL +query T +select array_to_string(column1, column4) from large_arrays_values; +---- +2,3,4,5,6,7,8,9,10 +11.12.13.14.15.16.17.18.20 +21-22-23-25-26-27-28-29-30 +31ok32ok33ok34ok35ok37ok38ok39ok40 +NULL +41$42$43$44$45$46$47$48$49$50 +51^52^54^55^56^57^58^59^60 +NULL + query TT select array_to_string(column1, '_'), array_to_string(make_array(1,2,3), '/') from arrays_values; ---- @@ -3263,6 +3300,18 @@ NULL 1/2/3 51_52_54_55_56_57_58_59_60 1/2/3 61_62_63_64_65_66_67_68_69_70 1/2/3 +query TT +select array_to_string(column1, '_'), array_to_string(make_array(1,2,3), '/') from large_arrays_values; +---- +2_3_4_5_6_7_8_9_10 1/2/3 +11_12_13_14_15_16_17_18_20 1/2/3 +21_22_23_25_26_27_28_29_30 1/2/3 +31_32_33_34_35_37_38_39_40 1/2/3 +NULL 1/2/3 +41_42_43_44_45_46_47_48_49_50 1/2/3 +51_52_54_55_56_57_58_59_60 1/2/3 +61_62_63_64_65_66_67_68_69_70 1/2/3 + query TT select array_to_string(column1, '_', '*'), array_to_string(make_array(make_array(1,2,3)), '.') from arrays_values; ---- @@ -3275,6 +3324,18 @@ NULL 1.2.3 51_52_*_54_55_56_57_58_59_60 1.2.3 61_62_63_64_65_66_67_68_69_70 1.2.3 +query TT +select array_to_string(column1, '_', '*'), array_to_string(make_array(make_array(1,2,3)), '.') from large_arrays_values; +---- +*_2_3_4_5_6_7_8_9_10 1.2.3 +11_12_13_14_15_16_17_18_*_20 1.2.3 +21_22_23_*_25_26_27_28_29_30 1.2.3 +31_32_33_34_35_*_37_38_39_40 1.2.3 +NULL 1.2.3 +41_42_43_44_45_46_47_48_49_50 1.2.3 +51_52_*_54_55_56_57_58_59_60 1.2.3 +61_62_63_64_65_66_67_68_69_70 1.2.3 + ## cardinality # cardinality scalar function