From 056b4852c1331b36ca0502c37432e36ea62aa8c1 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 2 Dec 2022 09:19:38 -0500 Subject: [PATCH 1/2] Fix panic in median "AggregateState is not a scalar aggregate" --- datafusion/common/src/scalar.rs | 6 +- datafusion/core/tests/sql/aggregates.rs | 84 +++++++ .../physical-expr/src/aggregate/median.rs | 237 ++++++++---------- 3 files changed, 189 insertions(+), 138 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index d7c5df0656ee..9c1df331c768 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -720,7 +720,7 @@ impl std::hash::Hash for ScalarValue { /// dictionary array #[inline] fn get_dict_value( - array: &ArrayRef, + array: &dyn Array, index: usize, ) -> (&ArrayRef, Option) { let dict_array = as_dictionary_array::(array).unwrap(); @@ -1962,7 +1962,7 @@ impl ScalarValue { } fn get_decimal_value_from_array( - array: &ArrayRef, + array: &dyn Array, index: usize, precision: u8, scale: i8, @@ -1977,7 +1977,7 @@ impl ScalarValue { } /// Converts a value in `array` at `index` into a ScalarValue - pub fn try_from_array(array: &ArrayRef, index: usize) -> Result { + pub fn try_from_array(array: &dyn Array, index: usize) -> Result { // handle NULL value if !array.is_valid(index) { return array.data_type().try_into(); diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 8b57a69fa61d..9fbed3a81c0e 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -436,6 +436,90 @@ async fn median_test( Ok(()) } +#[tokio::test] +// test case for https://github.com/apache/arrow-datafusion/issues/3105 +// has an intermediate grouping there +async fn median_multi() -> Result<()> { + let ctx = SessionContext::new(); + ctx.sql("create table cpu (host string, usage float) as select * from (values ('host0', 90.1), ('host1', 90.2), ('host1', 90.4));") + .await? 
+ .collect() + .await?; + + let sql = "select host, median(usage) from cpu group by host;"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-------+-------------------+", + "| host | MEDIAN(cpu.usage) |", + "+-------+-------------------+", + "| host0 | 90.1 |", + "| host1 | 90.3 |", + "+-------+-------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + + let sql = "select median(usage) from cpu;"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-------------------+", + "| MEDIAN(cpu.usage) |", + "+-------------------+", + "| 90.2 |", + "+-------------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn median_multi_odd() -> Result<()> { + let ctx = SessionContext::new(); + // data is not sorted and has an odd number of values per group + ctx.sql("create table cpu (host string, usage float) as select * from (values ('host0', 90.2), ('host1', 90.1), ('host1', 90.5), ('host0', 90.5), ('host1', 90.0), ('host1', 90.3), ('host0', 87.9), ('host1', 89.3) );") + .await.unwrap() + .collect() + .await.unwrap(); + + let sql = "select host, median(usage) from cpu group by host;"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-------+-------------------+", + "| host | MEDIAN(cpu.usage) |", + "+-------+-------------------+", + "| host0 | 90.2 |", + "| host1 | 90.1 |", + "+-------+-------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn median_multi_even() -> Result<()> { + let ctx = SessionContext::new(); + // data is not sorted and has an even number of values per group + ctx.sql("create table cpu (host string, usage float) as select * from (values ('host0', 90.2), ('host1', 90.1), ('host1', 90.5), ('host0', 90.5), ('host1', 90.0), ('host1', 90.3), ('host1', 90.2), ('host1', 90.3) );") + .await.unwrap() + .collect() + .await.unwrap(); + + let sql = "select host, 
median(usage) from cpu group by host;"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-------+-------------------+", + "| host | MEDIAN(cpu.usage) |", + "+-------+-------------------+", + "| host0 | 90.35 |", + "| host1 | 90.25 |", + "+-------+-------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + + Ok(()) +} + #[tokio::test] async fn csv_query_external_table_count() { let ctx = SessionContext::new(); diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs index 64d6fa7b4ac8..5ac9a46f28bc 100644 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ b/datafusion/physical-expr/src/aggregate/median.rs @@ -19,13 +19,9 @@ use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::{Array, ArrayRef, PrimitiveBuilder}; -use arrow::compute::sort; -use arrow::datatypes::{ - ArrowPrimitiveType, DataType, Field, Float32Type, Float64Type, Int16Type, Int32Type, - Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, -}; -use datafusion_common::cast::as_primitive_array; +use arrow::array::{Array, ArrayRef, UInt32Array}; +use arrow::compute::sort_to_indices; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{Accumulator, AggregateState}; use std::any::Any; @@ -74,9 +70,13 @@ impl AggregateExpr for Median { } fn state_fields(&self) -> Result> { + //Intermediate state is a list of the elements we have collected so far + let field = Field::new("item", self.data_type.clone(), true); + let data_type = DataType::List(Box::new(field)); + Ok(vec![Field::new( &format_state_name(&self.name, "median"), - self.data_type.clone(), + data_type, true, )]) } @@ -91,157 +91,124 @@ impl AggregateExpr for Median { } #[derive(Debug)] +/// The median accumulator accumulates the raw input values +/// as `ScalarValue`s +/// +/// The intermediate state is 
represented as a List of those scalars struct MedianAccumulator { data_type: DataType, - all_values: Vec, -} - -macro_rules! median { - ($SELF:ident, $TY:ty, $SCALAR_TY:ident, $TWO:expr) => {{ - let combined = combine_arrays::<$TY>($SELF.all_values.as_slice())?; - if combined.is_empty() { - return Ok(ScalarValue::Null); - } - let sorted = sort(&combined, None)?; - let array = as_primitive_array::<$TY>(&sorted)?; - let len = sorted.len(); - let mid = len / 2; - if len % 2 == 0 { - Ok(ScalarValue::$SCALAR_TY(Some( - (array.value(mid - 1) + array.value(mid)) / $TWO, - ))) - } else { - Ok(ScalarValue::$SCALAR_TY(Some(array.value(mid)))) - } - }}; + all_values: Vec, } impl Accumulator for MedianAccumulator { fn state(&self) -> Result> { - let mut vec: Vec = self - .all_values - .iter() - .map(|v| AggregateState::Array(v.clone())) - .collect(); - if vec.is_empty() { - match self.data_type { - DataType::UInt8 => vec.push(empty_array::()), - DataType::UInt16 => vec.push(empty_array::()), - DataType::UInt32 => vec.push(empty_array::()), - DataType::UInt64 => vec.push(empty_array::()), - DataType::Int8 => vec.push(empty_array::()), - DataType::Int16 => vec.push(empty_array::()), - DataType::Int32 => vec.push(empty_array::()), - DataType::Int64 => vec.push(empty_array::()), - DataType::Float32 => vec.push(empty_array::()), - DataType::Float64 => vec.push(empty_array::()), - _ => { - return Err(DataFusionError::Execution( - "unsupported data type for median".to_string(), - )) - } - } - } - Ok(vec) + let state = + ScalarValue::new_list(Some(self.all_values.clone()), self.data_type.clone()); + Ok(vec![AggregateState::Scalar(state)]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let x = values[0].clone(); - self.all_values.extend_from_slice(&[x]); - Ok(()) - } + assert_eq!(values.len(), 1); + let array = &values[0]; - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - for array in states { - self.all_values.extend_from_slice(&[array.clone()]); 
+ self.all_values.reserve(self.all_values.len() + array.len()); + for index in 0..array.len() { + self.all_values + .push(ScalarValue::try_from_array(array, index)?); } + Ok(()) } - fn evaluate(&self) -> Result { - match self.all_values[0].data_type() { - DataType::Int8 => median!(self, arrow::datatypes::Int8Type, Int8, 2), - DataType::Int16 => median!(self, arrow::datatypes::Int16Type, Int16, 2), - DataType::Int32 => median!(self, arrow::datatypes::Int32Type, Int32, 2), - DataType::Int64 => median!(self, arrow::datatypes::Int64Type, Int64, 2), - DataType::UInt8 => median!(self, arrow::datatypes::UInt8Type, UInt8, 2), - DataType::UInt16 => median!(self, arrow::datatypes::UInt16Type, UInt16, 2), - DataType::UInt32 => median!(self, arrow::datatypes::UInt32Type, UInt32, 2), - DataType::UInt64 => median!(self, arrow::datatypes::UInt64Type, UInt64, 2), - DataType::Float32 => { - median!(self, arrow::datatypes::Float32Type, Float32, 2_f32) - } - DataType::Float64 => { - median!(self, arrow::datatypes::Float64Type, Float64, 2_f64) + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + assert_eq!(states.len(), 1); + + let array = &states[0]; + for index in 0..array.len() { + match ScalarValue::try_from_array(array, index)? { + ScalarValue::List(Some(mut values), _) => { + self.all_values.append(&mut values); + } + ScalarValue::List(None, _) => {} // skip empty state + v => { + return Err(DataFusionError::Internal(format!( + "unexpected state in median. 
Expected DataType::List, got {:?}", + v + ))) + } } - _ => Err(DataFusionError::Execution( - "unsupported data type for median".to_string(), - )), } + Ok(()) } - fn size(&self) -> usize { - // TODO(crepererum): `DataType` is NOT fixed size, add `DataType::size` method to arrow (https://github.com/apache/arrow-rs/issues/3147) - std::mem::align_of_val(self) - + (std::mem::size_of::() * self.all_values.capacity()) - + self - .all_values + fn evaluate(&self) -> Result { + // Create an array of all the non null values and find the + // sorted indexes + let array = ScalarValue::iter_to_array( + self.all_values .iter() - .map(|array_ref| { - std::mem::size_of_val(array_ref.as_ref()) - + array_ref.get_array_memory_size() - }) - .sum::() - } -} + // ignore null values + .filter(|v| !v.is_null()) + .cloned(), + )?; -/// Create an empty array -fn empty_array() -> AggregateState { - AggregateState::Array(Arc::new(PrimitiveBuilder::::with_capacity(0).finish())) -} + // find the mid point + let len = array.len(); + let mid = len / 2; -/// Combine all non-null values from provided arrays into a single array -fn combine_arrays(arrays: &[ArrayRef]) -> Result { - let len = arrays.iter().map(|a| a.len() - a.null_count()).sum(); - let mut builder: PrimitiveBuilder = PrimitiveBuilder::with_capacity(len); - for array in arrays { - let array = as_primitive_array::(array)?; - for i in 0..array.len() { - if !array.is_null(i) { - builder.append_value(array.value(i)); + // only sort up to the top size/2 elements + let limit = Some(mid + 1); + let options = None; + let indices = sort_to_indices(&array, options, limit)?; + + // pick the relevant indices in the original arrays + let result = if len >= 2 && len % 2 == 0 { + // even number of values, average the two mid points + let s1 = scalar_at_index(&array, &indices, mid - 1)?; + let s2 = scalar_at_index(&array, &indices, mid)?; + match s1.add(s2)? 
{ + ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(v / 2)), + ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(v / 2)), + ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(v / 2)), + ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(v / 2)), + ScalarValue::UInt8(Some(v)) => ScalarValue::UInt8(Some(v / 2)), + ScalarValue::UInt16(Some(v)) => ScalarValue::UInt16(Some(v / 2)), + ScalarValue::UInt32(Some(v)) => ScalarValue::UInt32(Some(v / 2)), + ScalarValue::UInt64(Some(v)) => ScalarValue::UInt64(Some(v / 2)), + ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(v / 2.0)), + ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(v / 2.0)), + v => { + return Err(DataFusionError::Internal(format!( + "Unsupported type in MedianAccumulator: {:?}", + v + ))) + } } - } - } - Ok(Arc::new(builder.finish())) -} + } else { + // odd number of values, pick that one + scalar_at_index(&array, &indices, mid)? + }; -#[cfg(test)] -mod test { - use crate::aggregate::median::combine_arrays; - use arrow::array::{Int32Array, UInt32Array}; - use arrow::datatypes::{Int32Type, UInt32Type}; - use datafusion_common::Result; - use std::sync::Arc; - - #[test] - fn combine_i32_array() -> Result<()> { - let a = Arc::new(Int32Array::from(vec![1, 2, 3])); - let b = combine_arrays::(&[a.clone(), a])?; - assert_eq!( - "PrimitiveArray\n[\n 1,\n 2,\n 3,\n 1,\n 2,\n 3,\n]", - format!("{:?}", b) - ); - Ok(()) + Ok(result) } - #[test] - fn combine_u32_array() -> Result<()> { - let a = Arc::new(UInt32Array::from(vec![1, 2, 3])); - let b = combine_arrays::(&[a.clone(), a])?; - assert_eq!( - "PrimitiveArray\n[\n 1,\n 2,\n 3,\n 1,\n 2,\n 3,\n]", - format!("{:?}", b) - ); - Ok(()) + fn size(&self) -> usize { + // TODO(crepererum): `DataType` is NOT fixed size, add + // `DataType::size` method to arrow + // (https://github.com/apache/arrow-rs/issues/3147) + std::mem::align_of_val(self) + ScalarValue::size_of_vec(&self.all_values) } } + +/// Given a returns 
`array[indices[indicies_index]]` as a `ScalarValue` +fn scalar_at_index( + array: &dyn Array, + indices: &UInt32Array, + indicies_index: usize, +) -> Result { + let array_index = indices + .value(indicies_index) + .try_into() + .expect("Convert uint32 to usize"); + ScalarValue::try_from_array(array, array_index) +} From 5e72719205bb5efffac08b038b0032f0ba2bedfc Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 9 Dec 2022 05:50:40 -0500 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> --- datafusion/physical-expr/src/aggregate/median.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs index 5ac9a46f28bc..4276b45b21d1 100644 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ b/datafusion/physical-expr/src/aggregate/median.rs @@ -111,6 +111,7 @@ impl Accumulator for MedianAccumulator { assert_eq!(values.len(), 1); let array = &values[0]; + assert_eq!(array.data_type(), &self.data_type); self.all_values.reserve(self.all_values.len() + array.len()); for index in 0..array.len() { self.all_values @@ -124,6 +125,7 @@ impl Accumulator for MedianAccumulator { assert_eq!(states.len(), 1); let array = &states[0]; + assert!(matches!(array.data_type(), DataType::List(_))); for index in 0..array.len() { match ScalarValue::try_from_array(array, index)? { ScalarValue::List(Some(mut values), _) => {