From 0084c7388d1462b8027ffb3946d873487a52fe61 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 18 Aug 2022 17:35:11 -0400 Subject: [PATCH 1/2] Reduce code duplication creating ScalarValue::List --- datafusion/common/src/scalar.rs | 76 +++++------- .../src/aggregate/array_agg_distinct.rs | 28 ++--- datafusion/proto/src/lib.rs | 116 +++++------------- datafusion/sql/src/planner.rs | 10 +- 4 files changed, 73 insertions(+), 157 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index dff97d2f981f..417738648582 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -648,6 +648,11 @@ impl ScalarValue { ScalarValue::IntervalMonthDayNano(Some(val)) } + /// Create a new nullable ScalarValue::List with the specified child_type + pub fn new_list(scalars: Option>, child_type: DataType) -> Self { + Self::List(scalars, Box::new(Field::new("item", child_type, true))) + } + /// Getter for the `DataType` of the value pub fn get_datatype(&self) -> DataType { match self { @@ -1506,10 +1511,7 @@ impl ScalarValue { Some(scalar_vec) } }; - ScalarValue::List( - value, - Box::new(Field::new("item", nested_type.data_type().clone(), true)), - ) + ScalarValue::new_list(value, nested_type.data_type()) } DataType::Date32 => { typed_cast!(array, index, Date32Array, Date32) @@ -1610,10 +1612,7 @@ impl ScalarValue { Some(scalar_vec) } }; - ScalarValue::List( - value, - Box::new(Field::new("item", nested_type.data_type().clone(), true)), - ) + ScalarValue::new_list(value, nested_type.data_type()) } other => { return Err(DataFusionError::NotImplemented(format!( @@ -1951,10 +1950,9 @@ impl TryFrom<&DataType> for ScalarValue { index_type.clone(), Box::new(value_type.as_ref().try_into()?), ), - DataType::List(ref nested_type) => ScalarValue::List( - None, - Box::new(Field::new("item", nested_type.data_type().clone(), true)), - ), + DataType::List(ref nested_type) => { + ScalarValue::new_list(None, nested_type.data_type().clone()) + } DataType::Struct(fields) => { ScalarValue::Struct(None, Box::new(fields.clone())) } @@ -3124,20 +3122,12 @@ mod tests { assert_eq!(array, &expected); // Define list-of-structs scalars - let nl0 = ScalarValue::List( - Some(vec![s0.clone(), s1.clone()]), - Box::new(Field::new("item", s0.get_datatype(), true)), - ); + let nl0 = + ScalarValue::new_list(Some(vec![s0.clone(), s1.clone()]), s0.get_datatype()); - let nl1 = ScalarValue::List( - Some(vec![s2]), - Box::new(Field::new("item", s0.get_datatype(), true)), - ); + let nl1 = ScalarValue::new_list(Some(vec![s2]), s0.get_datatype()); - let nl2 = ScalarValue::List( - Some(vec![s1]), - Box::new(Field::new("item", s0.get_datatype(), true)), - ); + let nl2 = ScalarValue::new_list(Some(vec![s1]), s0.get_datatype()); // iter_to_array for list-of-struct let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); let array = array.as_any().downcast_ref::().unwrap(); @@ -3263,50 +3253,42 @@ mod tests { #[test] fn test_nested_lists() { // Define inner list scalars - let l1 = ScalarValue::List( + let l1 = ScalarValue::new_list( Some(vec![ - ScalarValue::List( + ScalarValue::new_list( Some(vec![ ScalarValue::from(1i32), ScalarValue::from(2i32), ScalarValue::from(3i32), ]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ), - ScalarValue::List( + ScalarValue::new_list( Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ), ]), - Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - )), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); - let l2 = ScalarValue::List( + let l2 = ScalarValue::new_list( Some(vec![ - ScalarValue::List( + ScalarValue::new_list( Some(vec![ScalarValue::from(6i32)]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ), - ScalarValue::List( + ScalarValue::new_list( Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ), ]), - Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - )), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); - let l3 = ScalarValue::List( - Some(vec![ScalarValue::List( + let l3 = ScalarValue::new_list( + Some(vec![ScalarValue::new_list( Some(vec![ScalarValue::from(9i32)]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, )]), Box::new(Field::new( "item", diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index f9899379d2c9..ff1c9053613a 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -120,9 +120,9 @@ impl DistinctArrayAggAccumulator { impl Accumulator for DistinctArrayAggAccumulator { fn state(&self) -> Result> { - Ok(vec![AggregateState::Scalar(ScalarValue::List( + Ok(vec![AggregateState::Scalar(ScalarValue::new_list( Some(self.values.clone().into_iter().collect()), - Box::new(Field::new("item", self.datatype.clone(), true)), + self.datatype.clone(), ))]) } @@ -151,9 +151,9 @@ impl Accumulator for DistinctArrayAggAccumulator { } fn evaluate(&self) -> Result { - Ok(ScalarValue::List( + Ok(ScalarValue::new_list( Some(self.values.clone().into_iter().collect()), - Box::new(Field::new("item", self.datatype.clone(), true)), + self.datatype.clone(), )) } } @@ -206,7 +206,7 @@ mod tests { fn distinct_array_agg_i32() -> Result<()> { let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); - let out = ScalarValue::List( + let out = ScalarValue::new_list( Some(vec![ ScalarValue::Int32(Some(1)), ScalarValue::Int32(Some(2)), @@ -214,7 +214,7 @@ mod tests { ScalarValue::Int32(Some(4)), ScalarValue::Int32(Some(5)), ]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ); check_distinct_array_agg(col, out, DataType::Int32) @@ -223,26 +223,22 @@ mod tests { #[test] fn distinct_array_agg_nested() -> Result<()> { // [[1, 2, 3], [4, 5]] - let l1 = ScalarValue::List( + let l1 = ScalarValue::new_list( Some(vec![ - ScalarValue::List( + ScalarValue::new_list( Some(vec![ ScalarValue::from(1i32), ScalarValue::from(2i32), ScalarValue::from(3i32), ]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ), - ScalarValue::List( + ScalarValue::new_list( Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ), ]), - Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - )), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); // [[6], [7, 8]] diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index c69723442790..eecca1b6ad59 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -289,43 +289,27 @@ mod roundtrip_tests { fn scalar_values_error_serialization() { let should_fail_on_seralize: Vec = vec![ // Should fail due to inconsistent types - ScalarValue::List( + ScalarValue::new_list( Some(vec![ ScalarValue::Int16(None), ScalarValue::Float32(Some(32.0)), ]), - new_box_field( - "item", - DataType::List(new_box_field("item", DataType::Int16, true)), - true, - ), + DataType::List(new_box_field("item", DataType::Int16, true)), ), - ScalarValue::List( + ScalarValue::new_list( Some(vec![ ScalarValue::Float32(None), ScalarValue::Float32(Some(32.0)), ]), - new_box_field( - "item", - DataType::List(new_box_field("item", DataType::Int16, true)), - true, - ), + DataType::List(new_box_field("item", DataType::Int16, true)), ), - ScalarValue::List( + ScalarValue::new_list( Some(vec![ - ScalarValue::List( + ScalarValue::new_list( None, - new_box_field( - "item", - DataType::List(new_box_field( - "level2", - DataType::Float32, - true, - )), - true, - ), + DataType::List(new_box_field("level2", DataType::Float32, true)), ), - ScalarValue::List( + ScalarValue::new_list( Some(vec![ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), @@ -333,38 +317,22 @@ mod roundtrip_tests { ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), ]), - new_box_field( - "item", - DataType::List(new_box_field( - "level2", - DataType::Float32, - true, - )), - true, - ), + DataType::List(new_box_field("level2", DataType::Float32, true)), ), - ScalarValue::List( + ScalarValue::new_list( None, - new_box_field( - "item", - DataType::List(new_box_field( - "lists are typed inconsistently", - DataType::Int16, - true, - )), + DataType::List(new_box_field( + "lists are typed inconsistently", + DataType::Int16, true, - ), + )), ), ]), - new_box_field( - "item", - DataType::List(new_box_field( - "level1", - DataType::List(new_box_field("level2", DataType::Float32, true)), - true, - )), + DataType::List(new_box_field( + "level1", + DataType::List(new_box_field("level2", DataType::Float32, true)), true, - ), + )), ), ]; @@ -397,7 +365,7 @@ mod roundtrip_tests { ScalarValue::UInt64(None), ScalarValue::Utf8(None), ScalarValue::LargeUtf8(None), - ScalarValue::List(None, new_box_field("item", DataType::Boolean, true)), + ScalarValue::new_list(None, DataType::Boolean), ScalarValue::Date32(None), ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), @@ -453,7 +421,7 @@ mod roundtrip_tests { ScalarValue::TimestampSecond(Some(i64::MAX), None), ScalarValue::TimestampSecond(Some(0), Some("UTC".to_string())), ScalarValue::TimestampSecond(None, None), - ScalarValue::List( + ScalarValue::new_list( Some(vec![ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), @@ -461,27 +429,15 @@ mod roundtrip_tests { ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), ]), - new_box_field( - "item", - DataType::List(new_box_field("level1", DataType::Float32, true)), - true, - ), + DataType::List(new_box_field("level1", DataType::Float32, true)), ), - ScalarValue::List( + ScalarValue::new_list( Some(vec![ - ScalarValue::List( + ScalarValue::new_list( None, - new_box_field( - "item", - DataType::List(new_box_field( - "level2", - DataType::Float32, - true, - )), - true, - ), + DataType::List(new_box_field("level2", DataType::Float32, true)), ), - ScalarValue::List( + ScalarValue::new_list( Some(vec![ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), @@ -489,26 +445,14 @@ mod roundtrip_tests { ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), ]), - new_box_field( - "item", - DataType::List(new_box_field( - "level2", - DataType::Float32, - true, - )), - true, - ), + DataType::List(new_box_field("level2", DataType::Float32, true)), ), ]), - new_box_field( - "item", - DataType::List(new_box_field( - "level1", - DataType::List(new_box_field("level2", DataType::Float32, true)), - true, - )), + DataType::List(new_box_field( + "level1", + DataType::List(new_box_field("level2", DataType::Float32, true)), true, - ), + )), ), ]; diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 28c82f80246f..4c46790235e7 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -2410,10 +2410,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { values.iter().map(|e| e.get_datatype()).collect(); if data_types.is_empty() { - Ok(Expr::Literal(ScalarValue::List( - None, - Box::new(Field::new("item", DataType::Utf8, true)), - ))) + Ok(lit(ScalarValue::new_list(None, DataType::Utf8))) } else if data_types.len() > 1 { Err(DataFusionError::NotImplemented(format!( "Arrays with different types are not supported: {:?}", @@ -2422,10 +2419,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { let data_type = values[0].get_datatype(); - Ok(Expr::Literal(ScalarValue::List( - Some(values), - Box::new(Field::new("item", data_type, true)), - ))) + Ok(lit(ScalarValue::new_list(Some(values), data_type))) } } } From 60084cbd8351f7d115f8750fc8cd94a3c99d9f16 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 18 Aug 2022 17:44:26 -0400 Subject: [PATCH 2/2] clean more --- datafusion/common/src/scalar.rs | 10 +--- .../physical-expr/src/aggregate/array_agg.rs | 60 +++++++------------ .../src/aggregate/array_agg_distinct.rs | 36 ++++------- .../src/aggregate/count_distinct.rs | 5 +- .../src/aggregate/sum_distinct.rs | 4 +- .../physical-expr/src/aggregate/tdigest.rs | 7 +-- 6 files changed, 42 insertions(+), 80 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 417738648582..531738a4909f 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -1511,7 +1511,7 @@ impl ScalarValue { Some(scalar_vec) } }; - ScalarValue::new_list(value, nested_type.data_type()) + ScalarValue::new_list(value, nested_type.data_type().clone()) } DataType::Date32 => { typed_cast!(array, index, Date32Array, Date32) @@ -1612,7 +1612,7 @@ impl ScalarValue { Some(scalar_vec) } }; - ScalarValue::new_list(value, nested_type.data_type()) + ScalarValue::new_list(value, nested_type.data_type().clone()) } other => { return Err(DataFusionError::NotImplemented(format!( @@ -3290,11 +3290,7 @@ mod tests { Some(vec![ScalarValue::from(9i32)]), DataType::Int32, )]), - Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - )), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index e7fd0937cc87..160e4477b102 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -148,9 +148,9 @@ impl Accumulator for ArrayAggAccumulator { } fn evaluate(&self) -> Result { - Ok(ScalarValue::List( + Ok(ScalarValue::new_list( Some(self.values.clone()), - Box::new(Field::new("item", self.datatype.clone(), true)), + self.datatype.clone(), )) } } @@ -171,7 +171,7 @@ mod tests { fn array_agg_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - let list = ScalarValue::List( + let list = ScalarValue::new_list( Some(vec![ ScalarValue::Int32(Some(1)), ScalarValue::Int32(Some(2)), @@ -179,7 +179,7 @@ mod tests { ScalarValue::Int32(Some(4)), ScalarValue::Int32(Some(5)), ]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ); generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) @@ -187,65 +187,49 @@ mod tests { #[test] fn array_agg_nested() -> Result<()> { - let l1 = ScalarValue::List( + let l1 = ScalarValue::new_list( Some(vec![ - ScalarValue::List( + ScalarValue::new_list( Some(vec![ ScalarValue::from(1i32), ScalarValue::from(2i32), ScalarValue::from(3i32), ]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ), - ScalarValue::List( + ScalarValue::new_list( Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ), ]), - Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - )), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); - let l2 = ScalarValue::List( + let l2 = ScalarValue::new_list( Some(vec![ - ScalarValue::List( + ScalarValue::new_list( Some(vec![ScalarValue::from(6i32)]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ), - ScalarValue::List( + ScalarValue::new_list( Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ), ]), - Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - )), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); - let l3 = ScalarValue::List( - Some(vec![ScalarValue::List( + let l3 = ScalarValue::new_list( + Some(vec![ScalarValue::new_list( Some(vec![ScalarValue::from(9i32)]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, )]), - Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - )), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); - let list = ScalarValue::List( + let list = ScalarValue::new_list( Some(vec![l1.clone(), l2.clone(), l3.clone()]), - Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - )), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index ff1c9053613a..a0ef021b807c 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -242,44 +242,32 @@ mod tests { ); // [[6], [7, 8]] - let l2 = ScalarValue::List( + let l2 = ScalarValue::new_list( Some(vec![ - ScalarValue::List( + ScalarValue::new_list( Some(vec![ScalarValue::from(6i32)]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ), - ScalarValue::List( + ScalarValue::new_list( Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, ), ]), - Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - )), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); // [[9]] - let l3 = ScalarValue::List( - Some(vec![ScalarValue::List( + let l3 = ScalarValue::new_list( + Some(vec![ScalarValue::new_list( Some(vec![ScalarValue::from(9i32)]), - Box::new(Field::new("item", DataType::Int32, true)), + DataType::Int32, )]), - Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - )), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); - let list = ScalarValue::List( + let list = ScalarValue::new_list( Some(vec![l1.clone(), l2.clone(), l3.clone()]), - Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - )), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); // Duplicate l1 in the input array and check that it is deduped in the output. diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 6060ddb4dc99..6dcd21d92e50 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -183,10 +183,7 @@ impl Accumulator for DistinctCountAccumulator { .iter() .map(|state_data_type| { let values = Box::new(Vec::new()); - ScalarValue::List( - Some(*values), - Box::new(Field::new("item", state_data_type.clone(), true)), - ) + ScalarValue::new_list(Some(*values), state_data_type.clone()) }) .collect::>(); diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index 96ba81834959..d2ab46bdbfd3 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -136,9 +136,9 @@ impl Accumulator for DistinctSumAccumulator { self.hash_values .iter() .for_each(|distinct_value| distinct_values.push(distinct_value.clone())); - vec![AggregateState::Scalar(ScalarValue::List( + vec![AggregateState::Scalar(ScalarValue::new_list( Some(distinct_values), - Box::new(Field::new("item", self.data_type.clone(), true)), + self.data_type.clone(), ))] }; Ok(state_out) diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/physical-expr/src/aggregate/tdigest.rs index 114eb185cf93..fa937d3e159b 100644 --- a/datafusion/physical-expr/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr/src/aggregate/tdigest.rs @@ -27,7 +27,7 @@ //! [TDigest sketch algorithm]: https://arxiv.org/abs/1902.04023 //! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::DataType; use datafusion_common::DataFusionError; use datafusion_common::Result; use datafusion_common::ScalarValue; @@ -624,10 +624,7 @@ impl TDigest { ScalarValue::Float64(Some(self.count.into_inner())), ScalarValue::Float64(Some(self.max.into_inner())), ScalarValue::Float64(Some(self.min.into_inner())), - ScalarValue::List( - Some(centroids), - Box::new(Field::new("item", DataType::Float64, true)), - ), + ScalarValue::new_list(Some(centroids), DataType::Float64), ] }