Skip to content

Commit

Permalink
WIP: Convert ARRAY_AGG and NTH_VALUE to UDAF
Browse files Browse the repository at this point in the history
Still has test failures that need to be addressed.
  • Loading branch information
eejbyfeldt committed Jun 20, 2024
1 parent 89def2c commit 9f3d98e
Show file tree
Hide file tree
Showing 28 changed files with 561 additions and 472 deletions.
6 changes: 3 additions & 3 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1590,10 +1590,10 @@ mod tests {
use datafusion_common::{Constraint, Constraints};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::{
array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction,
ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition,
cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation,
Volatility, WindowFrame, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::count_distinct;
use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};

Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col,
placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame,
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder,
scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{count, sum};
use datafusion_functions_aggregate::expr_fn::{array_agg, count, sum};

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
Schema::new(vec![Field::new_list(
"ARRAY_AGG(DISTINCT aggregate_test_100.c2)",
Field::new("item", DataType::UInt32, true),
false
true
),])
);

Expand Down
22 changes: 2 additions & 20 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@

//! Aggregate function module contains all built-in aggregate functions definitions
use std::sync::Arc;
use std::{fmt, str::FromStr};

use crate::utils;
use crate::{type_coercion::aggregates::*, Signature, Volatility};

use arrow::datatypes::{DataType, Field};
use arrow::datatypes::DataType;
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};

use strum_macros::EnumIter;
Expand All @@ -39,10 +38,6 @@ pub enum AggregateFunction {
Max,
/// Average
Avg,
/// Aggregation into an array
ArrayAgg,
/// N'th value in a group according to some ordering
NthValue,
/// Correlation
Correlation,
/// Grouping
Expand All @@ -56,8 +51,6 @@ impl AggregateFunction {
Min => "MIN",
Max => "MAX",
Avg => "AVG",
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
Correlation => "CORR",
Grouping => "GROUPING",
}
Expand All @@ -79,8 +72,6 @@ impl FromStr for AggregateFunction {
"max" => AggregateFunction::Max,
"mean" => AggregateFunction::Avg,
"min" => AggregateFunction::Min,
"array_agg" => AggregateFunction::ArrayAgg,
"nth_value" => AggregateFunction::NthValue,
// statistical
"corr" => AggregateFunction::Correlation,
// other
Expand Down Expand Up @@ -124,13 +115,7 @@ impl AggregateFunction {
correlation_return_type(&coerced_data_types[0])
}
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
"item",
coerced_data_types[0].clone(),
true,
)))),
AggregateFunction::Grouping => Ok(DataType::Int32),
AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
}
}
}
Expand All @@ -153,9 +138,7 @@ impl AggregateFunction {
pub fn signature(&self) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match self {
AggregateFunction::Grouping | AggregateFunction::ArrayAgg => {
Signature::any(1, Volatility::Immutable)
}
AggregateFunction::Grouping => Signature::any(1, Volatility::Immutable),
AggregateFunction::Min | AggregateFunction::Max => {
let valid = STRINGS
.iter()
Expand All @@ -171,7 +154,6 @@ impl AggregateFunction {
AggregateFunction::Avg => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
Expand Down
12 changes: 0 additions & 12 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,6 @@ pub fn max(expr: Expr) -> Expr {
))
}

/// Create an expression to represent the array_agg() aggregate function
pub fn array_agg(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::ArrayAgg,
vec![expr],
false,
None,
None,
None,
))
}

/// Create an expression to represent the avg() aggregate function
pub fn avg(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
Expand Down
8 changes: 1 addition & 7 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ pub fn coerce_types(
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;

match agg_fun {
AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
AggregateFunction::Min | AggregateFunction::Max => {
// min and max support the dictionary data type
// unpack the dictionary to get the value
Expand Down Expand Up @@ -131,7 +130,6 @@ pub fn coerce_types(
}
Ok(vec![Float64, Float64])
}
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
}
}
Expand Down Expand Up @@ -383,11 +381,7 @@ mod tests {

// test count, array_agg, approx_distinct, min, max.
// the coerced types are the same as the input types
let funs = vec![
AggregateFunction::ArrayAgg,
AggregateFunction::Min,
AggregateFunction::Max,
];
let funs = vec![AggregateFunction::Min, AggregateFunction::Max];
let input_types = vec![
vec![DataType::Int32],
vec![DataType::Decimal128(10, 2)],
Expand Down
4 changes: 4 additions & 0 deletions datafusion/functions-aggregate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ path = "src/lib.rs"
[dependencies]
ahash = { workspace = true }
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-schema = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
Expand All @@ -48,3 +49,6 @@ datafusion-physical-expr-common = { workspace = true }
log = { workspace = true }
paste = "1.0.14"
sqlparser = { workspace = true }

[dev-dependencies]
arrow-buffer = { workspace = true }
Loading

0 comments on commit 9f3d98e

Please sign in to comment.