From 7a70b1147238ba82ffe010f696cd7f4aee13436d Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Fri, 7 Oct 2022 22:05:34 +0800 Subject: [PATCH 1/3] move type coercion for agg function --- datafusion/optimizer/src/type_coercion.rs | 110 +++++++++- .../physical-expr/src/aggregate/build_in.rs | 188 ++++++++++-------- .../src/aggregate/coercion_rule.rs | 54 ----- datafusion/physical-expr/src/aggregate/mod.rs | 1 - 4 files changed, 217 insertions(+), 136 deletions(-) delete mode 100644 datafusion/physical-expr/src/aggregate/coercion_rule.rs diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index bb236fdde527..fb459f5590b6 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -29,8 +29,8 @@ use datafusion_expr::type_coercion::other::{ }; use datafusion_expr::utils::from_plan; use datafusion_expr::{ - is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, Expr, - LogicalPlan, Operator, + aggregate_function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, + is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, Operator, }; use datafusion_expr::{ExprSchemable, Signature}; use std::sync::Arc; @@ -322,6 +322,26 @@ impl ExprRewriter for TypeCoercionRewriter { }; Ok(expr) } + Expr::AggregateFunction { + fun, + args, + distinct, + filter, + } => { + let new_expr = coerce_agg_exprs_for_signature( + &fun, + &args, + &self.schema, + &aggregate_function::signature(&fun), + )?; + let expr = Expr::AggregateFunction { + fun, + args: new_expr, + distinct, + filter, + }; + Ok(expr) + } Expr::InList { expr, list, @@ -442,6 +462,33 @@ fn coerce_arguments_for_signature( .collect::>>() } +/// Returns the coerced exprs for each `input_exprs`. +/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the +/// data type of `input_exprs` need to be coerced. +fn coerce_agg_exprs_for_signature( + agg_fun: &AggregateFunction, + input_exprs: &[Expr], + schema: &DFSchema, + signature: &Signature, +) -> Result> { + if input_exprs.is_empty() { + return Ok(vec![]); + } + let current_types = input_exprs + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + + let coerced_types = + type_coercion::aggregates::coerce_types(agg_fun, ¤t_types, signature)?; + + input_exprs + .iter() + .enumerate() + .map(|(i, expr)| expr.clone().cast_to(&coerced_types[i], schema)) + .collect::>>() +} + #[cfg(test)] mod test { use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter}; @@ -449,7 +496,7 @@ mod test { use arrow::datatypes::DataType; use datafusion_common::{DFField, DFSchema, Result, ScalarValue}; use datafusion_expr::expr_rewriter::ExprRewritable; - use datafusion_expr::{cast, col, is_true, ColumnarValue}; + use datafusion_expr::{cast, col, is_true, AggregateFunction, ColumnarValue}; use datafusion_expr::{ lit, logical_plan::{EmptyRelation, Projection}, @@ -564,6 +611,63 @@ mod test { Ok(()) } + #[test] + fn agg_function_case() -> Result<()> { + let empty = empty(); + let fun: AggregateFunction = AggregateFunction::Avg; + let agg_expr = Expr::AggregateFunction { + fun, + args: vec![lit(12i64)], + distinct: false, + filter: None, + }; + let plan = + LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty, None)?); + let rule = TypeCoercion::new(); + let mut config = OptimizerConfig::default(); + let plan = rule.optimize(&plan, &mut config)?; + assert_eq!( + "Projection: AVG(Int64(12))\n EmptyRelation", + &format!("{:?}", plan) + ); + + let empty = empty_with_type(DataType::Int32); + let fun: AggregateFunction = AggregateFunction::Avg; + let agg_expr = Expr::AggregateFunction { + fun, + args: vec![col("a")], + distinct: false, + filter: None, + }; + let plan = + LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty, None)?); + let plan = rule.optimize(&plan, &mut config)?; + assert_eq!( + "Projection: AVG(a)\n EmptyRelation", + &format!("{:?}", plan) + ); + Ok(()) + } + + #[test] + fn agg_function_invalid_input() -> Result<()> { + let empty = empty(); + let fun: AggregateFunction = AggregateFunction::Avg; + let agg_expr = Expr::AggregateFunction { + fun, + args: vec![lit("1")], + distinct: false, + filter: None, + }; + let expr = Projection::try_new(vec![agg_expr], empty, None); + assert!(expr.is_err()); + assert_eq!( + "Plan(\"The function Avg does not support inputs of type Utf8.\")", + &format!("{:?}", expr.err().unwrap()) + ); + Ok(()) + } + #[test] fn binary_op_date32_add_interval() -> Result<()> { //CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640") diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index e3154488cf8f..597b5157554b 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -26,11 +26,9 @@ //! * Signature: see `Signature` //! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64. -use crate::aggregate::coercion_rule::coerce_exprs; use crate::{expressions, AggregateExpr, PhysicalExpr}; use arrow::datatypes::Schema; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::aggregate_function; use datafusion_expr::aggregate_function::return_type; pub use datafusion_expr::AggregateFunction; use std::sync::Arc; @@ -45,89 +43,72 @@ pub fn create_aggregate_expr( name: impl Into, ) -> Result> { let name = name.into(); - // get the coerced phy exprs if some expr need to be wrapped with the try cast. - let coerced_phy_exprs = coerce_exprs( - fun, - input_phy_exprs, - input_schema, - &aggregate_function::signature(fun), - )?; - if coerced_phy_exprs.is_empty() { - return Err(DataFusionError::Plan(format!( - "Invalid or wrong number of arguments passed to aggregate: '{}'", - name, - ))); - } - let coerced_exprs_types = coerced_phy_exprs - .iter() - .map(|e| e.data_type(input_schema)) - .collect::>>()?; - // get the result data type for this aggregate function let input_phy_types = input_phy_exprs .iter() .map(|e| e.data_type(input_schema)) .collect::>>()?; let return_type = return_type(fun, &input_phy_types)?; + let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { (AggregateFunction::Count, false) => Arc::new(expressions::Count::new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, return_type, )), (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new( - coerced_exprs_types, - coerced_phy_exprs, + input_phy_types, + input_phy_exprs, name, return_type, )), (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, return_type, )), (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, return_type, )), (AggregateFunction::Sum, true) => Arc::new(expressions::DistinctSum::new( - vec![coerced_phy_exprs[0].clone()], + vec![input_phy_exprs[0].clone()], name, return_type, )), (AggregateFunction::ApproxDistinct, _) => { Arc::new(expressions::ApproxDistinct::new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, - coerced_exprs_types[0].clone(), + input_phy_types[0].clone(), )) } (AggregateFunction::ArrayAgg, false) => Arc::new(expressions::ArrayAgg::new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, - coerced_exprs_types[0].clone(), + input_phy_types[0].clone(), )), (AggregateFunction::ArrayAgg, true) => { Arc::new(expressions::DistinctArrayAgg::new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, - coerced_exprs_types[0].clone(), + input_phy_types[0].clone(), )) } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, return_type, )), (AggregateFunction::Max, _) => Arc::new(expressions::Max::new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, return_type, )), (AggregateFunction::Avg, false) => Arc::new(expressions::Avg::new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, return_type, )), @@ -137,7 +118,7 @@ pub fn create_aggregate_expr( )); } (AggregateFunction::Variance, false) => Arc::new(expressions::Variance::new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, return_type, )), @@ -146,21 +127,17 @@ pub fn create_aggregate_expr( "VAR(DISTINCT) aggregations are not available".to_string(), )); } - (AggregateFunction::VariancePop, false) => { - Arc::new(expressions::VariancePop::new( - coerced_phy_exprs[0].clone(), - name, - return_type, - )) - } + (AggregateFunction::VariancePop, false) => Arc::new( + expressions::VariancePop::new(input_phy_exprs[0].clone(), name, return_type), + ), (AggregateFunction::VariancePop, true) => { return Err(DataFusionError::NotImplemented( "VAR_POP(DISTINCT) aggregations are not available".to_string(), )); } (AggregateFunction::Covariance, false) => Arc::new(expressions::Covariance::new( - coerced_phy_exprs[0].clone(), - coerced_phy_exprs[1].clone(), + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), name, return_type, )), @@ -171,8 +148,8 @@ pub fn create_aggregate_expr( } (AggregateFunction::CovariancePop, false) => { Arc::new(expressions::CovariancePop::new( - coerced_phy_exprs[0].clone(), - coerced_phy_exprs[1].clone(), + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), name, return_type, )) @@ -183,7 +160,7 @@ pub fn create_aggregate_expr( )); } (AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, return_type, )), @@ -193,7 +170,7 @@ pub fn create_aggregate_expr( )); } (AggregateFunction::StddevPop, false) => Arc::new(expressions::StddevPop::new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, return_type, )), @@ -204,8 +181,8 @@ pub fn create_aggregate_expr( } (AggregateFunction::Correlation, false) => { Arc::new(expressions::Correlation::new( - coerced_phy_exprs[0].clone(), - coerced_phy_exprs[1].clone(), + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), name, return_type, )) @@ -216,17 +193,17 @@ pub fn create_aggregate_expr( )); } (AggregateFunction::ApproxPercentileCont, false) => { - if coerced_phy_exprs.len() == 2 { + if input_phy_exprs.len() == 2 { Arc::new(expressions::ApproxPercentileCont::new( // Pass in the desired percentile expr - coerced_phy_exprs, + input_phy_exprs, name, return_type, )?) } else { Arc::new(expressions::ApproxPercentileCont::new_with_max_size( // Pass in the desired percentile expr - coerced_phy_exprs, + input_phy_exprs, name, return_type, )?) @@ -241,7 +218,7 @@ pub fn create_aggregate_expr( (AggregateFunction::ApproxPercentileContWithWeight, false) => { Arc::new(expressions::ApproxPercentileContWithWeight::new( // Pass in the desired percentile expr - coerced_phy_exprs, + input_phy_exprs, name, return_type, )?) @@ -254,7 +231,7 @@ pub fn create_aggregate_expr( } (AggregateFunction::ApproxMedian, false) => { Arc::new(expressions::ApproxMedian::try_new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, return_type, )?) @@ -265,7 +242,7 @@ pub fn create_aggregate_expr( )); } (AggregateFunction::Median, false) => Arc::new(expressions::Median::new( - coerced_phy_exprs[0].clone(), + input_phy_exprs[0].clone(), name, return_type, )), @@ -281,13 +258,14 @@ pub fn create_aggregate_expr( mod tests { use super::*; use crate::expressions::{ - ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, Correlation, - Count, Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, - Variance, + try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, + Correlation, Count, Covariance, DistinctArrayAgg, DistinctCount, Max, Min, + Stddev, Sum, Variance, }; use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_expr::type_coercion::aggregates::NUMERICS; + use datafusion_expr::{aggregate_function, type_coercion, Signature}; #[test] fn test_count_arragg_approx_expr() -> Result<()> { @@ -311,7 +289,7 @@ mod tests { let input_phy_exprs: Vec> = vec![Arc::new( expressions::Column::new_with_schema("c1", &input_schema).unwrap(), )]; - let result_agg_phy_exprs = create_aggregate_expr( + let result_agg_phy_exprs = create_physical_agg_expr_for_test( &fun, false, &input_phy_exprs[0..1], @@ -344,9 +322,9 @@ mod tests { DataType::List(Box::new(Field::new( "item", data_type.clone(), - true + true, ))), - false + false, ), result_agg_phy_exprs.field().unwrap() ); @@ -354,7 +332,7 @@ mod tests { _ => {} }; - let result_distinct = create_aggregate_expr( + let result_distinct = create_physical_agg_expr_for_test( &fun, true, &input_phy_exprs[0..1], @@ -387,9 +365,9 @@ mod tests { DataType::List(Box::new(Field::new( "item", data_type.clone(), - true + true, ))), - false + false, ), result_agg_phy_exprs.field().unwrap() ); @@ -412,7 +390,7 @@ mod tests { ), Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))), ]; - let result_agg_phy_exprs = create_aggregate_expr( + let result_agg_phy_exprs = create_physical_agg_expr_for_test( &AggregateFunction::ApproxPercentileCont, false, &input_phy_exprs[..], @@ -441,7 +419,7 @@ mod tests { ), Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))), ]; - let err = create_aggregate_expr( + let err = create_physical_agg_expr_for_test( &AggregateFunction::ApproxPercentileCont, false, &input_phy_exprs[..], @@ -472,7 +450,7 @@ mod tests { let input_phy_exprs: Vec> = vec![Arc::new( expressions::Column::new_with_schema("c1", &input_schema).unwrap(), )]; - let result_agg_phy_exprs = create_aggregate_expr( + let result_agg_phy_exprs = create_physical_agg_expr_for_test( &fun, false, &input_phy_exprs[0..1], @@ -521,7 +499,7 @@ mod tests { let input_phy_exprs: Vec> = vec![Arc::new( expressions::Column::new_with_schema("c1", &input_schema).unwrap(), )]; - let result_agg_phy_exprs = create_aggregate_expr( + let result_agg_phy_exprs = create_physical_agg_expr_for_test( &fun, false, &input_phy_exprs[0..1], @@ -583,7 +561,7 @@ mod tests { let input_phy_exprs: Vec> = vec![Arc::new( expressions::Column::new_with_schema("c1", &input_schema).unwrap(), )]; - let result_agg_phy_exprs = create_aggregate_expr( + let result_agg_phy_exprs = create_physical_agg_expr_for_test( &fun, false, &input_phy_exprs[0..1], @@ -621,7 +599,7 @@ mod tests { let input_phy_exprs: Vec> = vec![Arc::new( expressions::Column::new_with_schema("c1", &input_schema).unwrap(), )]; - let result_agg_phy_exprs = create_aggregate_expr( + let result_agg_phy_exprs = create_physical_agg_expr_for_test( &fun, false, &input_phy_exprs[0..1], @@ -659,7 +637,7 @@ mod tests { let input_phy_exprs: Vec> = vec![Arc::new( expressions::Column::new_with_schema("c1", &input_schema).unwrap(), )]; - let result_agg_phy_exprs = create_aggregate_expr( + let result_agg_phy_exprs = create_physical_agg_expr_for_test( &fun, false, &input_phy_exprs[0..1], @@ -697,7 +675,7 @@ mod tests { let input_phy_exprs: Vec> = vec![Arc::new( expressions::Column::new_with_schema("c1", &input_schema).unwrap(), )]; - let result_agg_phy_exprs = create_aggregate_expr( + let result_agg_phy_exprs = create_physical_agg_expr_for_test( &fun, false, &input_phy_exprs[0..1], @@ -744,7 +722,7 @@ mod tests { .unwrap(), ), ]; - let result_agg_phy_exprs = create_aggregate_expr( + let result_agg_phy_exprs = create_physical_agg_expr_for_test( &fun, false, &input_phy_exprs[0..2], @@ -791,7 +769,7 @@ mod tests { .unwrap(), ), ]; - let result_agg_phy_exprs = create_aggregate_expr( + let result_agg_phy_exprs = create_physical_agg_expr_for_test( &fun, false, &input_phy_exprs[0..2], @@ -838,7 +816,7 @@ mod tests { .unwrap(), ), ]; - let result_agg_phy_exprs = create_aggregate_expr( + let result_agg_phy_exprs = create_physical_agg_expr_for_test( &fun, false, &input_phy_exprs[0..2], @@ -876,7 +854,7 @@ mod tests { let input_phy_exprs: Vec> = vec![Arc::new( expressions::Column::new_with_schema("c1", &input_schema).unwrap(), )]; - let result_agg_phy_exprs = create_aggregate_expr( + let result_agg_phy_exprs = create_physical_agg_expr_for_test( &fun, false, &input_phy_exprs[0..1], @@ -1065,4 +1043,58 @@ mod tests { let observed = return_type(&AggregateFunction::Stddev, &[DataType::Utf8]); assert!(observed.is_err()); } + + // Helper function + // Create aggregate expr with type coercion + fn create_physical_agg_expr_for_test( + fun: &AggregateFunction, + distinct: bool, + input_phy_exprs: &[Arc], + input_schema: &Schema, + name: impl Into, + ) -> Result> { + let name = name.into(); + let coerced_phy_exprs = coerce_exprs_for_test( + fun, + input_phy_exprs, + input_schema, + &aggregate_function::signature(fun), + )?; + if coerced_phy_exprs.is_empty() { + return Err(DataFusionError::Plan(format!( + "Invalid or wrong number of arguments passed to aggregate: '{}'", + name, + ))); + } + create_aggregate_expr(fun, distinct, &coerced_phy_exprs, input_schema, name) + } + + // Returns the coerced exprs for each `input_exprs`. + // Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the + // data type of `input_exprs` need to be coerced. + fn coerce_exprs_for_test( + agg_fun: &AggregateFunction, + input_exprs: &[Arc], + schema: &Schema, + signature: &Signature, + ) -> Result>> { + if input_exprs.is_empty() { + return Ok(vec![]); + } + let input_types = input_exprs + .iter() + .map(|e| e.data_type(schema)) + .collect::>>()?; + + // get the coerced data types + let coerced_types = + type_coercion::aggregates::coerce_types(agg_fun, &input_types, signature)?; + + // try cast if need + input_exprs + .iter() + .zip(coerced_types.into_iter()) + .map(|(expr, coerced_type)| try_cast(expr.clone(), schema, coerced_type)) + .collect::>>() + } } diff --git a/datafusion/physical-expr/src/aggregate/coercion_rule.rs b/datafusion/physical-expr/src/aggregate/coercion_rule.rs deleted file mode 100644 index a8c68390ae57..000000000000 --- a/datafusion/physical-expr/src/aggregate/coercion_rule.rs +++ /dev/null @@ -1,54 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Define coercion rules for Aggregate function. - -use crate::expressions::try_cast; -use crate::PhysicalExpr; -use arrow::datatypes::Schema; -use datafusion_common::Result; -use datafusion_expr::{type_coercion, AggregateFunction, Signature}; -use std::sync::Arc; - -/// Returns the coerced exprs for each `input_exprs`. -/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the -/// data type of `input_exprs` need to be coerced. -pub fn coerce_exprs( - agg_fun: &AggregateFunction, - input_exprs: &[Arc], - schema: &Schema, - signature: &Signature, -) -> Result>> { - if input_exprs.is_empty() { - return Ok(vec![]); - } - let input_types = input_exprs - .iter() - .map(|e| e.data_type(schema)) - .collect::>>()?; - - // get the coerced data types - let coerced_types = - type_coercion::aggregates::coerce_types(agg_fun, &input_types, signature)?; - - // try cast if need - input_exprs - .iter() - .zip(coerced_types.into_iter()) - .map(|(expr, coerced_type)| try_cast(expr.clone(), schema, coerced_type)) - .collect::>>() -} diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index ec338eb68005..f6374687403e 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -31,7 +31,6 @@ pub(crate) mod approx_percentile_cont_with_weight; pub(crate) mod array_agg; pub(crate) mod array_agg_distinct; pub(crate) mod average; -pub(crate) mod coercion_rule; pub(crate) mod correlation; pub(crate) mod count; pub(crate) mod count_distinct; From edfb8debdebb766188da2f19a68b10072d483ea6 Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Sat, 8 Oct 2022 13:49:56 +0800 Subject: [PATCH 2/3] move type coercion for agg udaf --- datafusion/core/src/physical_plan/udaf.rs | 13 ++-- datafusion/optimizer/src/type_coercion.rs | 80 ++++++++++++++++++++++- 2 files changed, 83 insertions(+), 10 deletions(-) diff --git a/datafusion/core/src/physical_plan/udaf.rs b/datafusion/core/src/physical_plan/udaf.rs index e017bb5ad6d4..659ff560d140 100644 --- a/datafusion/core/src/physical_plan/udaf.rs +++ b/datafusion/core/src/physical_plan/udaf.rs @@ -26,9 +26,7 @@ use arrow::{ datatypes::{DataType, Schema}, }; -use super::{ - expressions::format_state_name, type_coercion::coerce, Accumulator, AggregateExpr, -}; +use super::{expressions::format_state_name, Accumulator, AggregateExpr}; use crate::error::Result; use crate::physical_plan::PhysicalExpr; pub use datafusion_expr::AggregateUDF; @@ -43,18 +41,15 @@ pub fn create_aggregate_expr( input_schema: &Schema, name: impl Into, ) -> Result> { - // coerce - let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &fun.signature)?; - - let coerced_exprs_types = coerced_phy_exprs + let input_exprs_types = input_phy_exprs .iter() .map(|arg| arg.data_type(input_schema)) .collect::>>()?; Ok(Arc::new(AggregateFunctionExpr { fun: fun.clone(), - args: coerced_phy_exprs.clone(), - data_type: (fun.return_type)(&coerced_exprs_types)?.as_ref().clone(), + args: input_phy_exprs.to_vec(), + data_type: (fun.return_type)(&input_exprs_types)?.as_ref().clone(), name: name.into(), })) } diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index fb459f5590b6..a15c0b1fd558 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -342,6 +342,19 @@ impl ExprRewriter for TypeCoercionRewriter { }; Ok(expr) } + Expr::AggregateUDF { fun, args, filter } => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + &fun.signature, + )?; + let expr = Expr::AggregateUDF { + fun, + args: new_expr, + filter, + }; + Ok(expr) + } Expr::InList { expr, list, @@ -496,13 +509,17 @@ mod test { use arrow::datatypes::DataType; use datafusion_common::{DFField, DFSchema, Result, ScalarValue}; use datafusion_expr::expr_rewriter::ExprRewritable; - use datafusion_expr::{cast, col, is_true, AggregateFunction, ColumnarValue}; + use datafusion_expr::{ + cast, col, create_udaf, is_true, AccumulatorFunctionImplementation, + AggregateFunction, AggregateUDF, ColumnarValue, StateTypeFunction, + }; use datafusion_expr::{ lit, logical_plan::{EmptyRelation, Projection}, Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; + use datafusion_physical_expr::expressions::AvgAccumulator; use std::sync::Arc; #[test] @@ -611,6 +628,67 @@ mod test { Ok(()) } + #[test] + fn agg_udaf() -> Result<()> { + let empty = empty(); + let my_avg = create_udaf( + "MY_AVG", + DataType::Float64, + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))), + Arc::new(vec![DataType::UInt64, DataType::Float64]), + ); + let udaf = Expr::AggregateUDF { + fun: Arc::new(my_avg), + args: vec![lit(10i64)], + filter: None, + }; + let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty, None)?); + let rule = TypeCoercion::new(); + let mut config = OptimizerConfig::default(); + let plan = rule.optimize(&plan, &mut config)?; + assert_eq!( + "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation", + &format!("{:?}", plan) + ); + Ok(()) + } + + #[test] + fn agg_udaf_invalid_input() -> Result<()> { + let empty = empty(); + let return_type: ReturnTypeFunction = + Arc::new(move |_| Ok(Arc::new(DataType::Float64).clone())); + let state_type: StateTypeFunction = Arc::new(move |_| { + Ok(Arc::new(vec![DataType::UInt64, DataType::Float64]).clone()) + }); + let accumulator: AccumulatorFunctionImplementation = + Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))); + let my_avg = AggregateUDF::new( + "MY_AVG", + &Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), + &return_type, + &accumulator, + &state_type, + ); + let udaf = Expr::AggregateUDF { + fun: Arc::new(my_avg), + args: vec![lit("10")], + filter: None, + }; + let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty, None)?); + let rule = TypeCoercion::new(); + let mut config = OptimizerConfig::default(); + let plan = rule.optimize(&plan, &mut config); + assert!(plan.is_err()); + assert_eq!( + "Plan(\"Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed.\")", + &format!("{:?}", plan.err().unwrap()) + ); + Ok(()) + } + #[test] fn agg_function_case() -> Result<()> { let empty = empty(); From 6cf7447f5c68b74e896e0fe1e288ea6a3e1402fa Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Tue, 11 Oct 2022 10:30:23 +0800 Subject: [PATCH 3/3] fix lint --- datafusion/optimizer/src/type_coercion.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index f65e8c20443a..89d5d660b885 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -690,10 +690,9 @@ mod test { fn agg_udaf_invalid_input() -> Result<()> { let empty = empty(); let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Float64).clone())); - let state_type: StateTypeFunction = Arc::new(move |_| { - Ok(Arc::new(vec![DataType::UInt64, DataType::Float64]).clone()) - }); + Arc::new(move |_| Ok(Arc::new(DataType::Float64))); + let state_type: StateTypeFunction = + Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64]))); let accumulator: AccumulatorFunctionImplementation = Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))); let my_avg = AggregateUDF::new(