Skip to content

Commit

Permalink
Remove AggregateState wrapper (#4582)
Browse files Browse the repository at this point in the history
* Remove AggregateState wrapper

* Remove more unwrap

* Fix logical conflicts

* Remove unecessary array
  • Loading branch information
alamb authored Dec 14, 2022
1 parent 84d3ae8 commit 5d424ef
Show file tree
Hide file tree
Showing 25 changed files with 97 additions and 164 deletions.
7 changes: 3 additions & 4 deletions datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use datafusion::arrow::{
array::ArrayRef, array::Float32Array, datatypes::DataType, record_batch::RecordBatch,
};
use datafusion::from_slice::FromSlice;
use datafusion::logical_expr::AggregateState;
use datafusion::{error::Result, physical_plan::Accumulator};
use datafusion::{logical_expr::Volatility, prelude::*, scalar::ScalarValue};
use datafusion_common::cast::as_float64_array;
Expand Down Expand Up @@ -108,10 +107,10 @@ impl Accumulator for GeometricMean {
// This function serializes our state to `ScalarValue`, which DataFusion uses
// to pass this state between execution stages.
// Note that this can be arbitrary data.
fn state(&self) -> Result<Vec<AggregateState>> {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![
AggregateState::Scalar(ScalarValue::from(self.prod)),
AggregateState::Scalar(ScalarValue::from(self.n)),
ScalarValue::from(self.prod),
ScalarValue::from(self.n),
])
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_plan/aggregates/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ fn create_batch_from_map(
accumulators.group_states.iter().map(|group_state| {
group_state.accumulator_set[x]
.state()
.and_then(|x| x[y].as_scalar().map(|v| v.clone()))
.map(|x| x[y].clone())
.expect("unexpected accumulator state in hash aggregate")
}),
)?;
Expand Down
8 changes: 3 additions & 5 deletions datafusion/core/src/physical_plan/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ mod tests {
use arrow::datatypes::{DataType, Field, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::cast::as_primitive_array;
use datafusion_expr::{create_udaf, Accumulator, AggregateState, Volatility};
use datafusion_expr::{create_udaf, Accumulator, Volatility};
use futures::FutureExt;

fn create_test_schema(partitions: usize) -> Result<(Arc<CsvExec>, SchemaRef)> {
Expand All @@ -193,10 +193,8 @@ mod tests {
struct MyCount(i64);

impl Accumulator for MyCount {
fn state(&self) -> Result<Vec<AggregateState>> {
Ok(vec![AggregateState::Scalar(ScalarValue::Int64(Some(
self.0,
)))])
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::Int64(Some(self.0))])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/tests/sql/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use datafusion::{
physical_plan::{expressions::AvgAccumulator, functions::make_scalar_function},
};
use datafusion_common::{cast::as_int32_array, ScalarValue};
use datafusion_expr::{create_udaf, Accumulator, AggregateState, LogicalPlanBuilder};
use datafusion_expr::{create_udaf, Accumulator, LogicalPlanBuilder};

/// test that casting happens on udfs.
/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
Expand Down Expand Up @@ -175,7 +175,7 @@ fn udaf_as_window_func() -> Result<()> {
struct MyAccumulator;

impl Accumulator for MyAccumulator {
fn state(&self) -> Result<Vec<AggregateState>> {
fn state(&self) -> Result<Vec<ScalarValue>> {
unimplemented!()
}

Expand Down
9 changes: 2 additions & 7 deletions datafusion/core/tests/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use datafusion::{
},
assert_batches_eq,
error::Result,
logical_expr::AggregateState,
logical_expr::{
AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, Signature,
StateTypeFunction, TypeSignature, Volatility,
Expand Down Expand Up @@ -210,12 +209,8 @@ impl FirstSelector {
}

impl Accumulator for FirstSelector {
fn state(&self) -> Result<Vec<AggregateState>> {
let state = self
.to_state()
.into_iter()
.map(AggregateState::Scalar)
.collect::<Vec<_>>();
fn state(&self) -> Result<Vec<ScalarValue>> {
let state = self.to_state().into_iter().collect::<Vec<_>>();

Ok(state)
}
Expand Down
51 changes: 11 additions & 40 deletions datafusion/expr/src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,20 @@ pub trait Accumulator: Send + Sync + Debug {
/// accumulator (that ran on different partitions, for
/// example).
///
/// The state can be a different type than the output of the
/// [`Accumulator`]
/// The state can be and often is a different type than the output
/// type of the [`Accumulator`].
///
/// See [`merge_batch`] for more details on the merging process.
///
/// For example, in the case of an average, for which we track `sum` and `n`,
/// this function should return a vector of two values, sum and n.
fn state(&self) -> Result<Vec<AggregateState>>;
/// Some accumulators can return multiple values for their
/// intermediate states. For example average, tracks `sum` and
/// `n`, and this function should return
/// a vector of two values, sum and n.
///
/// `ScalarValue::List` can also be used to pass multiple values
/// if the number of intermediate values is not known at planning
/// time (e.g. median)
fn state(&self) -> Result<Vec<ScalarValue>>;

/// Updates the accumulator's state from a vector of arrays.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
Expand Down Expand Up @@ -80,38 +86,3 @@ pub trait Accumulator: Send + Sync + Debug {
/// not the `len`
fn size(&self) -> usize;
}

/// Representation of internal accumulator state. Accumulators can potentially have a mix of
/// scalar and array values. It may be desirable to add custom aggregator states here as well
/// in the future (perhaps `Custom(Box<dyn Any>)`?).
#[derive(Debug)]
pub enum AggregateState {
/// Simple scalar value. Note that `ScalarValue::List` can be used to pass multiple
/// values around
Scalar(ScalarValue),
/// Arrays can be used instead of `ScalarValue::List` and could potentially have better
/// performance with large data sets, although this has not been verified. It also allows
/// for use of arrow kernels with less overhead.
Array(ArrayRef),
}

impl AggregateState {
/// Access the aggregate state as a scalar value. An error will occur if the
/// state is not a scalar value.
pub fn as_scalar(&self) -> Result<&ScalarValue> {
match &self {
Self::Scalar(v) => Ok(v),
_ => Err(DataFusionError::Internal(
"AggregateState is not a scalar aggregate".to_string(),
)),
}
}

/// Access the aggregate state as an array value.
pub fn to_array(&self) -> ArrayRef {
match &self {
Self::Scalar(v) => v.to_array(),
Self::Array(array) => array.clone(),
}
}
}
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub mod utils;
pub mod window_frame;
pub mod window_function;

pub use accumulator::{Accumulator, AggregateState};
pub use accumulator::Accumulator;
pub use aggregate_function::AggregateFunction;
pub use built_in_function::BuiltinScalarFunction;
pub use columnar_value::ColumnarValue;
Expand Down
6 changes: 3 additions & 3 deletions datafusion/physical-expr/src/aggregate/approx_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use arrow::datatypes::{
};
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{Accumulator, AggregateState};
use datafusion_expr::Accumulator;
use std::any::Any;
use std::convert::TryFrom;
use std::convert::TryInto;
Expand Down Expand Up @@ -231,8 +231,8 @@ macro_rules! default_accumulator_impl {
Ok(())
}

fn state(&self) -> Result<Vec<AggregateState>> {
let value = AggregateState::Scalar(ScalarValue::from(&self.hll));
fn state(&self) -> Result<Vec<ScalarValue>> {
let value = ScalarValue::from(&self.hll);
Ok(vec![value])
}

Expand Down
11 changes: 3 additions & 8 deletions datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use arrow::{
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_expr::{Accumulator, AggregateState};
use datafusion_expr::Accumulator;
use std::{any::Any, iter, sync::Arc};

/// APPROX_PERCENTILE_CONT aggregate expression
Expand Down Expand Up @@ -357,13 +357,8 @@ impl ApproxPercentileAccumulator {
}

impl Accumulator for ApproxPercentileAccumulator {
fn state(&self) -> Result<Vec<AggregateState>> {
Ok(self
.digest
.to_scalar_state()
.into_iter()
.map(AggregateState::Scalar)
.collect())
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(self.digest.to_scalar_state().into_iter().collect())
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use arrow::{

use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::{Accumulator, AggregateState};
use datafusion_expr::Accumulator;

use std::{any::Any, sync::Arc};

Expand Down Expand Up @@ -114,7 +114,7 @@ impl ApproxPercentileWithWeightAccumulator {
}

impl Accumulator for ApproxPercentileWithWeightAccumulator {
fn state(&self) -> Result<Vec<AggregateState>> {
fn state(&self) -> Result<Vec<ScalarValue>> {
self.approx_percentile_cont_accumulator.state()
}

Expand Down
6 changes: 3 additions & 3 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{Accumulator, AggregateState};
use datafusion_expr::Accumulator;
use std::any::Any;
use std::sync::Arc;

Expand Down Expand Up @@ -143,8 +143,8 @@ impl Accumulator for ArrayAggAccumulator {
})
}

fn state(&self) -> Result<Vec<AggregateState>> {
Ok(vec![AggregateState::Scalar(self.evaluate()?)])
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

fn evaluate(&self) -> Result<ScalarValue> {
Expand Down
8 changes: 4 additions & 4 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::{Accumulator, AggregateState};
use datafusion_expr::Accumulator;

/// Expression for a ARRAY_AGG(DISTINCT) aggregation.
#[derive(Debug)]
Expand Down Expand Up @@ -119,11 +119,11 @@ impl DistinctArrayAggAccumulator {
}

impl Accumulator for DistinctArrayAggAccumulator {
fn state(&self) -> Result<Vec<AggregateState>> {
Ok(vec![AggregateState::Scalar(ScalarValue::new_list(
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::new_list(
Some(self.values.clone().into_iter().collect()),
self.datatype.clone(),
))])
)])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
Expand Down
9 changes: 3 additions & 6 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use arrow::{
};
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{Accumulator, AggregateState};
use datafusion_expr::Accumulator;
use datafusion_row::accessor::RowAccessor;

/// AVG aggregate expression
Expand Down Expand Up @@ -150,11 +150,8 @@ impl AvgAccumulator {
}

impl Accumulator for AvgAccumulator {
fn state(&self) -> Result<Vec<AggregateState>> {
Ok(vec![
AggregateState::Scalar(ScalarValue::from(self.count)),
AggregateState::Scalar(self.sum.clone()),
])
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::from(self.count), self.sum.clone()])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
Expand Down
16 changes: 8 additions & 8 deletions datafusion/physical-expr/src/aggregate/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{AggregateExpr, PhysicalExpr};
use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::{Accumulator, AggregateState};
use datafusion_expr::Accumulator;
use std::any::Any;
use std::sync::Arc;

Expand Down Expand Up @@ -133,14 +133,14 @@ impl CorrelationAccumulator {
}

impl Accumulator for CorrelationAccumulator {
fn state(&self) -> Result<Vec<AggregateState>> {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![
AggregateState::Scalar(ScalarValue::from(self.covar.get_count())),
AggregateState::Scalar(ScalarValue::from(self.covar.get_mean1())),
AggregateState::Scalar(ScalarValue::from(self.stddev1.get_m2())),
AggregateState::Scalar(ScalarValue::from(self.covar.get_mean2())),
AggregateState::Scalar(ScalarValue::from(self.stddev2.get_m2())),
AggregateState::Scalar(ScalarValue::from(self.covar.get_algo_const())),
ScalarValue::from(self.covar.get_count()),
ScalarValue::from(self.covar.get_mean1()),
ScalarValue::from(self.stddev1.get_m2()),
ScalarValue::from(self.covar.get_mean2()),
ScalarValue::from(self.stddev2.get_m2()),
ScalarValue::from(self.covar.get_algo_const()),
])
}

Expand Down
8 changes: 3 additions & 5 deletions datafusion/physical-expr/src/aggregate/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use arrow::datatypes::DataType;
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{Accumulator, AggregateState};
use datafusion_expr::Accumulator;
use datafusion_row::accessor::RowAccessor;

use crate::expressions::format_state_name;
Expand Down Expand Up @@ -119,10 +119,8 @@ impl CountAccumulator {
}

impl Accumulator for CountAccumulator {
fn state(&self) -> Result<Vec<AggregateState>> {
Ok(vec![AggregateState::Scalar(ScalarValue::Int64(Some(
self.count,
)))])
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::Int64(Some(self.count))])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
Expand Down
Loading

0 comments on commit 5d424ef

Please sign in to comment.