Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve avg/sum Aggregator performance for Decimal #5866

Merged
merged 6 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2008,7 +2008,12 @@ mod tests {
DataType::Float64,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
)?))
}),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down
44 changes: 40 additions & 4 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::expressions::{Avg, CastExpr, Column, Sum};
use datafusion_physical_expr::{
expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr,
};
use std::any::Any;
use std::collections::HashMap;

use arrow::compute::DEFAULT_CAST_OPTIONS;
use std::sync::Arc;

mod no_grouping;
Expand Down Expand Up @@ -554,9 +555,44 @@ fn aggregate_expressions(
col_idx_base: usize,
) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
match mode {
AggregateMode::Partial => {
Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect())
}
AggregateMode::Partial => Ok(aggr_expr
.iter()
.map(|agg| {
let pre_cast_type = if let Some(Sum {
data_type,
pre_cast_to_sum_type,
..
}) = agg.as_any().downcast_ref::<Sum>()
{
if *pre_cast_to_sum_type {
Some(data_type.clone())
} else {
None
}
} else if let Some(Avg {
sum_data_type,
pre_cast_to_sum_type,
..
}) = agg.as_any().downcast_ref::<Avg>()
{
if *pre_cast_to_sum_type {
Some(sum_data_type.clone())
} else {
None
}
} else {
None
};
agg.expressions()
.into_iter()
.map(|expr| {
pre_cast_type.clone().map_or(expr.clone(), |cast_type| {
Arc::new(CastExpr::new(expr, cast_type, DEFAULT_CAST_OPTIONS))
})
})
.collect::<Vec<_>>()
})
.collect()),
// in this mode, we build the merge expressions of the aggregation
AggregateMode::Final | AggregateMode::FinalPartitioned => {
let mut col_idx_base = col_idx_base;
Expand Down
7 changes: 6 additions & 1 deletion datafusion/core/tests/sql/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,12 @@ async fn simple_udaf() -> Result<()> {
DataType::Float64,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
)?))
}),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down
13 changes: 13 additions & 0 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,19 @@ pub fn return_type(
}
}

/// Returns the data type that the AVG aggregate uses for its internal running sum.
///
/// The input expression types are first coerced against the AVG signature; the
/// sum type is then derived from the first coerced type via `avg_sum_type`.
///
/// Note: this function *must* return the same type that the respective physical
/// expression returns, or the execution panics.
pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result<DataType> {
    let avg = AggregateFunction::Avg;
    let avg_signature = signature(&avg);
    let coerced = crate::type_coercion::aggregates::coerce_types(
        &avg,
        input_expr_types,
        &avg_signature,
    )?;
    // The sum type is derived from the first coerced argument type.
    avg_sum_type(&coerced[0])
}

/// the signatures supported by the function `fun`.
pub fn signature(fun: &AggregateFunction) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
Expand Down
15 changes: 15 additions & 0 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,21 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
}
}

/// Returns the internal sum type used while computing an average over `arg_type`.
///
/// * `Decimal128(p, s)` maps to `Decimal128(min(DECIMAL128_MAX_PRECISION, p + 10), s)`,
///   mirroring Spark, where the sum type of AVG is `DECIMAL(min(38, precision + 10), s)`.
/// * Any type contained in `NUMERICS` maps to `Float64`.
///
/// # Errors
///
/// Returns `DataFusionError::Plan` for every other argument type.
pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
    match arg_type {
        DataType::Decimal128(precision, scale) => {
            // Saturating addition: a plain `+ 10` on the integer precision would
            // panic (debug) or wrap around (release) for an out-of-range precision;
            // saturating and then clamping yields the maximum precision instead.
            let new_precision =
                DECIMAL128_MAX_PRECISION.min(precision.saturating_add(10));
            Ok(DataType::Decimal128(new_precision, *scale))
        }
        arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
        other => Err(DataFusionError::Plan(format!(
            "AVG does not support {other:?}"
        ))),
    }
}

pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
Expand Down
15 changes: 12 additions & 3 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,12 @@ mod test {
DataType::Float64,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
)?))
}),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);
let udaf = Expr::AggregateUDF {
Expand All @@ -887,8 +892,12 @@ mod test {
Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
let state_type: StateTypeFunction =
Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
let accumulator: AccumulatorFunctionImplementation =
Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?)));
let accumulator: AccumulatorFunctionImplementation = Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
)?))
});
let my_avg = AggregateUDF::new(
"MY_AVG",
&Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
Expand Down
72 changes: 52 additions & 20 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::aggregate::row_accumulator::{
};
use crate::aggregate::sum;
use crate::aggregate::sum::sum_batch;
use crate::aggregate::utils::calculate_result_decimal_for_avg;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use arrow::compute;
Expand All @@ -34,6 +35,7 @@ use arrow::{
array::{ArrayRef, UInt64Array},
datatypes::Field,
};
use arrow_array::Array;
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;
Expand All @@ -44,25 +46,44 @@ use datafusion_row::accessor::RowAccessor;
pub struct Avg {
name: String,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
pub sum_data_type: DataType,
rt_data_type: DataType,
pub pre_cast_to_sum_type: bool,
}

impl Avg {
/// Create a new AVG aggregate function
pub fn new(
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
sum_data_type: DataType,
) -> Self {
Self::new_with_pre_cast(expr, name, sum_data_type.clone(), sum_data_type, false)
}

pub fn new_with_pre_cast(
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
sum_data_type: DataType,
rt_data_type: DataType,
cast_to_sum_type: bool,
) -> Self {
// The internal sum data type of AVG only supports the Float64 and Decimal128 data types.
assert!(matches!(
sum_data_type,
DataType::Float64 | DataType::Decimal128(_, _)
));
// The result of AVG only supports the Float64 and Decimal128 data types.
assert!(matches!(
data_type,
rt_data_type,
DataType::Float64 | DataType::Decimal128(_, _)
));
Self {
name: name.into(),
expr,
data_type,
sum_data_type,
rt_data_type,
pre_cast_to_sum_type: cast_to_sum_type,
}
}
}
Expand All @@ -74,13 +95,14 @@ impl AggregateExpr for Avg {
}

fn field(&self) -> Result<Field> {
Ok(Field::new(&self.name, self.data_type.clone(), true))
Ok(Field::new(&self.name, self.rt_data_type.clone(), true))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(AvgAccumulator::try_new(
// avg is f64 or decimal
&self.data_type,
&self.sum_data_type,
&self.rt_data_type,
)?))
}

Expand All @@ -93,7 +115,7 @@ impl AggregateExpr for Avg {
),
Field::new(
format_state_name(&self.name, "sum"),
self.data_type.clone(),
self.sum_data_type.clone(),
true,
),
])
Expand All @@ -108,7 +130,7 @@ impl AggregateExpr for Avg {
}

fn row_accumulator_supported(&self) -> bool {
is_row_accumulator_support_dtype(&self.data_type)
is_row_accumulator_support_dtype(&self.sum_data_type)
}

fn supports_bounded_execution(&self) -> bool {
Expand All @@ -121,7 +143,7 @@ impl AggregateExpr for Avg {
) -> Result<Box<dyn RowAccumulator>> {
Ok(Box::new(AvgRowAccumulator::new(
start_index,
self.data_type.clone(),
self.sum_data_type.clone(),
)))
}

Expand All @@ -130,7 +152,10 @@ impl AggregateExpr for Avg {
}

fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(AvgAccumulator::try_new(&self.data_type)?))
Ok(Box::new(AvgAccumulator::try_new(
&self.sum_data_type,
&self.rt_data_type,
)?))
}
}

Expand All @@ -139,14 +164,18 @@ impl AggregateExpr for Avg {
pub struct AvgAccumulator {
// sum is used for null
sum: ScalarValue,
sum_data_type: DataType,
return_data_type: DataType,
count: u64,
}

impl AvgAccumulator {
/// Creates a new `AvgAccumulator`
pub fn try_new(datatype: &DataType) -> Result<Self> {
pub fn try_new(datatype: &DataType, return_data_type: &DataType) -> Result<Self> {
Ok(Self {
sum: ScalarValue::try_from(datatype)?,
sum_data_type: datatype.clone(),
return_data_type: return_data_type.clone(),
count: 0,
})
}
Expand All @@ -163,14 +192,14 @@ impl Accumulator for AvgAccumulator {
self.count += (values.len() - values.data().null_count()) as u64;
self.sum = self
.sum
.add(&sum::sum_batch(values, &self.sum.get_datatype())?)?;
.add(&sum::sum_batch(values, &self.sum_data_type)?)?;
Ok(())
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = &values[0];
self.count -= (values.len() - values.data().null_count()) as u64;
let delta = sum_batch(values, &self.sum.get_datatype())?;
let delta = sum_batch(values, &self.sum_data_type)?;
self.sum = self.sum.sub(&delta)?;
Ok(())
}
Expand All @@ -183,7 +212,7 @@ impl Accumulator for AvgAccumulator {
// sums are summed
self.sum = self
.sum
.add(&sum::sum_batch(&states[1], &self.sum.get_datatype())?)?;
.add(&sum::sum_batch(&states[1], &self.sum_data_type)?)?;
Ok(())
}

Expand All @@ -195,12 +224,15 @@ impl Accumulator for AvgAccumulator {
ScalarValue::Decimal128(value, precision, scale) => {
Ok(match value {
None => ScalarValue::Decimal128(None, precision, scale),
// TODO add the checker for overflow the precision
Some(v) => ScalarValue::Decimal128(
Some(v / self.count as i128),
precision,
scale,
),
Some(value) => {
// The sum type and the return type may now differ, so the sum must be converted to the return type.
calculate_result_decimal_for_avg(
value,
self.count as i128,
scale,
&self.return_data_type,
)?
}
})
}
_ => Err(DataFusionError::Internal(
Expand Down
Loading