Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ignore] Convert Average to UDAF #10958

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ use datafusion_common::{
};
use datafusion_expr::lit;
use datafusion_expr::{
avg, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
UNNAMED_TABLE,
};
// FIXME: Import avg from udaf
use datafusion_expr::{case, is_null};
use datafusion_functions_aggregate::expr_fn::{count, median, stddev, sum};

Expand Down Expand Up @@ -547,14 +548,15 @@ impl DataFrame {
.collect::<Vec<_>>(),
),
// mean aggregation
self.clone().aggregate(
vec![],
original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| avg(col(f.name())).alias(f.name()))
.collect::<Vec<_>>(),
),
// FIXME Uncomment when avg is implement
// self.clone().aggregate(
// vec![],
// original_schema_fields
// .clone()
// .filter(|f| f.data_type().is_numeric())
// .map(|f| avg(col(f.name())).alias(f.name()))
// .collect::<Vec<_>>(),
// ),
// std aggregation
self.clone().aggregate(
vec![],
Expand Down Expand Up @@ -1807,7 +1809,8 @@ mod tests {
let aggr_expr = vec![
min(col("c12")),
max(col("c12")),
avg(col("c12")),
// FIXME Uncomment when avg is implement
// avg(col("c12")),
sum(col("c12")),
count(col("c12")),
count_distinct(col("c12")),
Expand Down
23 changes: 1 addition & 22 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub enum AggregateFunction {
/// Maximum
Max,
/// Average
Avg,
// Avg,
/// Aggregation into an array
ArrayAgg,
/// N'th value in a group according to some ordering
Expand Down Expand Up @@ -67,7 +67,6 @@ impl AggregateFunction {
match self {
Min => "MIN",
Max => "MAX",
Avg => "AVG",
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
Correlation => "CORR",
Expand All @@ -93,14 +92,12 @@ impl FromStr for AggregateFunction {
fn from_str(name: &str) -> Result<AggregateFunction> {
Ok(match name {
// general
"avg" => AggregateFunction::Avg,
"bit_and" => AggregateFunction::BitAnd,
"bit_or" => AggregateFunction::BitOr,
"bit_xor" => AggregateFunction::BitXor,
"bool_and" => AggregateFunction::BoolAnd,
"bool_or" => AggregateFunction::BoolOr,
"max" => AggregateFunction::Max,
"mean" => AggregateFunction::Avg,
"min" => AggregateFunction::Min,
"array_agg" => AggregateFunction::ArrayAgg,
"nth_value" => AggregateFunction::NthValue,
Expand Down Expand Up @@ -153,7 +150,6 @@ impl AggregateFunction {
AggregateFunction::Correlation => {
correlation_return_type(&coerced_data_types[0])
}
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
"item",
coerced_data_types[0].clone(),
Expand All @@ -166,19 +162,6 @@ impl AggregateFunction {
}
}

/// Returns the internal sum datatype of the avg aggregate function.
pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result<DataType> {
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.
let fun = AggregateFunction::Avg;
let coerced_data_types = crate::type_coercion::aggregates::coerce_types(
&fun,
input_expr_types,
&fun.signature(),
)?;
avg_sum_type(&coerced_data_types[0])
}

impl AggregateFunction {
/// the signatures supported by the function `fun`.
pub fn signature(&self) -> Signature {
Expand Down Expand Up @@ -207,10 +190,6 @@ impl AggregateFunction {
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable)
}

AggregateFunction::Avg => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
Expand Down
12 changes: 0 additions & 12 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,18 +180,6 @@ pub fn array_agg(expr: Expr) -> Expr {
))
}

/// Create an expression to represent the avg() aggregate function
pub fn avg(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Avg,
vec![expr],
false,
None,
None,
None,
))
}

/// Return a new expression with bitwise AND
pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr(BinaryExpr::new(
Expand Down
73 changes: 0 additions & 73 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,26 +101,6 @@ pub fn coerce_types(
// unpack the dictionary to get the value
get_min_max_result_type(input_types)
}
AggregateFunction::Avg => {
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval
let v = match &input_types[0] {
Decimal128(p, s) => Decimal128(*p, *s),
Decimal256(p, s) => Decimal256(*p, *s),
d if d.is_numeric() => Float64,
Dictionary(_, v) => {
return coerce_types(agg_fun, &[v.as_ref().clone()], signature)
}
_ => {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
)
}
};
Ok(vec![v])
}
AggregateFunction::BitAnd
| AggregateFunction::BitOr
| AggregateFunction::BitXor => {
Expand Down Expand Up @@ -422,59 +402,6 @@ pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool {
mod tests {
use super::*;

#[test]
fn test_aggregate_coerce_types() {
// test input args with error number input types
let fun = AggregateFunction::Min;
let input_types = vec![DataType::Int64, DataType::Int32];
let signature = fun.signature();
let result = coerce_types(&fun, &input_types, &signature);
assert_eq!("Error during planning: The function MIN expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace());

let fun = AggregateFunction::Avg;
// test input args is invalid data type for avg
let input_types = vec![DataType::Utf8];
let signature = fun.signature();
let result = coerce_types(&fun, &input_types, &signature);
assert_eq!(
"Error during planning: The function Avg does not support inputs of type Utf8.",
result.unwrap_err().strip_backtrace()
);

// test count, array_agg, approx_distinct, min, max.
// the coerced types is same with input types
let funs = vec![
AggregateFunction::ArrayAgg,
AggregateFunction::Min,
AggregateFunction::Max,
];
let input_types = vec![
vec![DataType::Int32],
vec![DataType::Decimal128(10, 2)],
vec![DataType::Decimal256(1, 1)],
vec![DataType::Utf8],
];
for fun in funs {
for input_type in &input_types {
let signature = fun.signature();
let result = coerce_types(&fun, input_type, &signature);
assert_eq!(*input_type, result.unwrap());
}
}

// test avg
let fun = AggregateFunction::Avg;
let signature = fun.signature();
let r = coerce_types(&fun, &[DataType::Int32], &signature).unwrap();
assert_eq!(r[0], DataType::Float64);
let r = coerce_types(&fun, &[DataType::Float32], &signature).unwrap();
assert_eq!(r[0], DataType::Float64);
let r = coerce_types(&fun, &[DataType::Decimal128(20, 3)], &signature).unwrap();
assert_eq!(r[0], DataType::Decimal128(20, 3));
let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)], &signature).unwrap();
assert_eq!(r[0], DataType::Decimal256(20, 3));
}

#[test]
fn test_avg_return_data_type() -> Result<()> {
let data_type = DataType::Decimal128(10, 5);
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ pub mod variance;
pub mod approx_median;
pub mod approx_percentile_cont;
pub mod approx_percentile_cont_with_weight;
mod average;

use crate::approx_percentile_cont::approx_percentile_cont_udaf;
use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf;
Expand Down
74 changes: 2 additions & 72 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use arrow::datatypes::Schema;
use datafusion_common::{exec_err, not_impl_err, Result};
use datafusion_expr::AggregateFunction;

#[allow(unused)]
use crate::aggregate::average::Avg;
use crate::expressions::{self, Literal};
use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
Expand Down Expand Up @@ -138,12 +139,6 @@ pub fn create_aggregate_expr(
name,
data_type,
)),
(AggregateFunction::Avg, false) => {
Arc::new(Avg::new(input_phy_exprs[0].clone(), name, data_type))
}
(AggregateFunction::Avg, true) => {
return not_impl_err!("AVG(DISTINCT) aggregations are not available");
}
(AggregateFunction::Correlation, false) => {
Arc::new(expressions::Correlation::new(
input_phy_exprs[0].clone(),
Expand Down Expand Up @@ -202,7 +197,7 @@ mod tests {
use datafusion_expr::{type_coercion, Signature};

use crate::expressions::{
try_cast, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr,
try_cast, ArrayAgg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr,
DistinctArrayAgg, Max, Min,
};

Expand Down Expand Up @@ -415,44 +410,6 @@ mod tests {
Ok(())
}

#[test]
fn test_sum_avg_expr() -> Result<()> {
let funcs = vec![AggregateFunction::Avg];
let data_types = vec![
DataType::UInt32,
DataType::UInt64,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
];
for fun in funcs {
for data_type in &data_types {
let input_schema =
Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
)];
let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&fun,
false,
&input_phy_exprs[0..1],
&input_schema,
"c1",
)?;
if fun == AggregateFunction::Avg {
assert!(result_agg_phy_exprs.as_any().is::<Avg>());
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
Field::new("c1", DataType::Float64, true),
result_agg_phy_exprs.field().unwrap()
);
};
}
}
Ok(())
}

#[test]
fn test_min_max() -> Result<()> {
let observed = AggregateFunction::Min.return_type(&[DataType::Utf8])?;
Expand All @@ -474,33 +431,6 @@ mod tests {
Ok(())
}

#[test]
fn test_avg_return_type() -> Result<()> {
let observed = AggregateFunction::Avg.return_type(&[DataType::Float32])?;
assert_eq!(DataType::Float64, observed);

let observed = AggregateFunction::Avg.return_type(&[DataType::Float64])?;
assert_eq!(DataType::Float64, observed);

let observed = AggregateFunction::Avg.return_type(&[DataType::Int32])?;
assert_eq!(DataType::Float64, observed);

let observed =
AggregateFunction::Avg.return_type(&[DataType::Decimal128(10, 6)])?;
assert_eq!(DataType::Decimal128(14, 10), observed);

let observed =
AggregateFunction::Avg.return_type(&[DataType::Decimal128(36, 6)])?;
assert_eq!(DataType::Decimal128(38, 10), observed);
Ok(())
}

#[test]
fn test_avg_no_utf8() {
let observed = AggregateFunction::Avg.return_type(&[DataType::Utf8]);
assert!(observed.is_err());
}

// Helper function
// Create aggregate expr with type coercion
fn create_physical_agg_expr_for_test(
Expand Down
3 changes: 2 additions & 1 deletion datafusion/physical-expr/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr;
pub(crate) mod array_agg;
pub(crate) mod array_agg_distinct;
pub(crate) mod array_agg_ordered;
pub(crate) mod average;
// FIXME: Delete Me
mod average;
pub(crate) mod bit_and_or_xor;
pub(crate) mod bool_and_or;
pub(crate) mod correlation;
Expand Down
2 changes: 0 additions & 2 deletions datafusion/physical-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ pub mod helpers {
pub use crate::aggregate::array_agg::ArrayAgg;
pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg;
pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg;
pub use crate::aggregate::average::Avg;
pub use crate::aggregate::average::AvgAccumulator;
pub use crate::aggregate::bit_and_or_xor::{BitAnd, BitOr, BitXor, DistinctBitXor};
pub use crate::aggregate::bool_and_or::{BoolAnd, BoolOr};
pub use crate::aggregate::build_in::create_aggregate_expr;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ enum AggregateFunction {
MIN = 0;
MAX = 1;
// SUM = 2;
AVG = 3;
// AVG = 3;
// COUNT = 4;
// APPROX_DISTINCT = 5;
ARRAY_AGG = 6;
Expand Down
3 changes: 0 additions & 3 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading