Skip to content

Commit

Permalink
move type coercion of agg and agg_udaf to logical phase (#3768)
Browse files Browse the repository at this point in the history
  • Loading branch information
liukun4515 authored and pull[bot] committed Nov 3, 2022
1 parent 831bf58 commit 25a14ec
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 145 deletions.
13 changes: 4 additions & 9 deletions datafusion/core/src/physical_plan/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ use arrow::{
datatypes::{DataType, Schema},
};

use super::{
expressions::format_state_name, type_coercion::coerce, Accumulator, AggregateExpr,
};
use super::{expressions::format_state_name, Accumulator, AggregateExpr};
use crate::error::Result;
use crate::physical_plan::PhysicalExpr;
pub use datafusion_expr::AggregateUDF;
Expand All @@ -43,18 +41,15 @@ pub fn create_aggregate_expr(
input_schema: &Schema,
name: impl Into<String>,
) -> Result<Arc<dyn AggregateExpr>> {
// coerce
let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &fun.signature)?;

let coerced_exprs_types = coerced_phy_exprs
let input_exprs_types = input_phy_exprs
.iter()
.map(|arg| arg.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;

Ok(Arc::new(AggregateFunctionExpr {
fun: fun.clone(),
args: coerced_phy_exprs.clone(),
data_type: (fun.return_type)(&coerced_exprs_types)?.as_ref().clone(),
args: input_phy_exprs.to_vec(),
data_type: (fun.return_type)(&input_exprs_types)?.as_ref().clone(),
name: name.into(),
}))
}
Expand Down
186 changes: 183 additions & 3 deletions datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use datafusion_expr::type_coercion::other::{
};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
Expr, LogicalPlan, Operator,
aggregate_function, function, is_false, is_not_false, is_not_true, is_not_unknown,
is_true, is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, Operator,
};
use datafusion_expr::{ExprSchemable, Signature};
use std::sync::Arc;
Expand Down Expand Up @@ -407,6 +407,39 @@ impl ExprRewriter for TypeCoercionRewriter {
};
Ok(expr)
}
Expr::AggregateFunction {
fun,
args,
distinct,
filter,
} => {
let new_expr = coerce_agg_exprs_for_signature(
&fun,
&args,
&self.schema,
&aggregate_function::signature(&fun),
)?;
let expr = Expr::AggregateFunction {
fun,
args: new_expr,
distinct,
filter,
};
Ok(expr)
}
Expr::AggregateUDF { fun, args, filter } => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
)?;
let expr = Expr::AggregateUDF {
fun,
args: new_expr,
filter,
};
Ok(expr)
}
expr => Ok(expr),
}
}
Expand Down Expand Up @@ -448,6 +481,33 @@ fn coerce_arguments_for_signature(
.collect::<Result<Vec<_>>>()
}

/// Coerces the arguments of a built-in aggregate function to the types accepted by
/// `signature`.
///
/// The current type of every expression in `input_exprs` is resolved against `schema`,
/// the target types are obtained from `type_coercion::aggregates::coerce_types`, and each
/// expression is then converted to its positional target type with `cast_to`.
///
/// An empty `input_exprs` yields an empty vector without consulting the signature.
///
/// # Errors
///
/// Propagates the first error raised while resolving an argument type, while computing
/// the coerced types, or while casting an argument.
fn coerce_agg_exprs_for_signature(
    agg_fun: &AggregateFunction,
    input_exprs: &[Expr],
    schema: &DFSchema,
    signature: &Signature,
) -> Result<Vec<Expr>> {
    if input_exprs.is_empty() {
        return Ok(Vec::new());
    }

    let mut arg_types = Vec::with_capacity(input_exprs.len());
    for arg in input_exprs {
        arg_types.push(arg.get_type(schema)?);
    }

    let target_types =
        type_coercion::aggregates::coerce_types(agg_fun, &arg_types, signature)?;

    let mut coerced = Vec::with_capacity(input_exprs.len());
    for (idx, arg) in input_exprs.iter().enumerate() {
        // Arguments and coerced types are paired by position.
        coerced.push(arg.clone().cast_to(&target_types[idx], schema)?);
    }
    Ok(coerced)
}

#[cfg(test)]
mod test {
use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter};
Expand All @@ -456,14 +516,17 @@ mod test {
use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
use datafusion_expr::expr_rewriter::ExprRewritable;
use datafusion_expr::{
cast, col, concat, concat_ws, is_true, BuiltinScalarFunction, ColumnarValue,
cast, col, concat, concat_ws, create_udaf, is_true,
AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF,
BuiltinScalarFunction, ColumnarValue, StateTypeFunction,
};
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF,
Signature, Volatility,
};
use datafusion_physical_expr::expressions::AvgAccumulator;
use std::sync::Arc;

#[test]
Expand Down Expand Up @@ -596,6 +659,123 @@ mod test {
Ok(())
}

/// Runs `TypeCoercion` over a projection containing a UDAF declared with a `Float64`
/// input and called with an `Int64` literal, and checks the debug rendering of the
/// optimized plan.
#[test]
fn agg_udaf() -> Result<()> {
    let udaf_def = create_udaf(
        "MY_AVG",
        DataType::Float64,
        Arc::new(DataType::Float64),
        Volatility::Immutable,
        Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
        Arc::new(vec![DataType::UInt64, DataType::Float64]),
    );
    let call = Expr::AggregateUDF {
        fun: Arc::new(udaf_def),
        args: vec![lit(10i64)],
        filter: None,
    };
    let projection = Projection::try_new(vec![call], empty(), None)?;
    let unoptimized = LogicalPlan::Projection(projection);

    let mut config = OptimizerConfig::default();
    let optimized = TypeCoercion::new().optimize(&unoptimized, &mut config)?;

    let rendered = format!("{:?}", optimized);
    assert_eq!(
        "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation",
        &rendered
    );
    Ok(())
}

/// Runs `TypeCoercion` over a projection containing a UDAF whose signature is uniform
/// over `Float64` but which is called with a string literal, and checks that the rule
/// fails with the expected plan error.
#[test]
fn agg_udaf_invalid_input() -> Result<()> {
    let return_type: ReturnTypeFunction =
        Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
    let state_type: StateTypeFunction =
        Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
    let accumulator: AccumulatorFunctionImplementation =
        Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?)));
    let signature =
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable);
    let udaf_def = AggregateUDF::new(
        "MY_AVG",
        &signature,
        &return_type,
        &accumulator,
        &state_type,
    );

    let call = Expr::AggregateUDF {
        fun: Arc::new(udaf_def),
        args: vec![lit("10")],
        filter: None,
    };
    let projection = Projection::try_new(vec![call], empty(), None)?;
    let unoptimized = LogicalPlan::Projection(projection);

    let mut config = OptimizerConfig::default();
    let outcome = TypeCoercion::new().optimize(&unoptimized, &mut config);
    assert!(outcome.is_err());

    let rendered = format!("{:?}", outcome.err().unwrap());
    assert_eq!(
        "Plan(\"Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed.\")",
        &rendered
    );
    Ok(())
}

/// Runs `TypeCoercion` over `AVG` applied first to an `Int64` literal and then to
/// column `a` of an `Int32` input, and checks the debug rendering of each optimized plan.
#[test]
fn agg_function_case() -> Result<()> {
    let rule = TypeCoercion::new();
    let mut config = OptimizerConfig::default();

    // Scenario 1: AVG over an Int64 literal on the empty relation.
    let avg_of_literal = Expr::AggregateFunction {
        fun: AggregateFunction::Avg,
        args: vec![lit(12i64)],
        distinct: false,
        filter: None,
    };
    let literal_plan = LogicalPlan::Projection(Projection::try_new(
        vec![avg_of_literal],
        empty(),
        None,
    )?);
    let optimized = rule.optimize(&literal_plan, &mut config)?;
    let rendered = format!("{:?}", optimized);
    assert_eq!("Projection: AVG(Int64(12))\n EmptyRelation", &rendered);

    // Scenario 2: AVG over column `a` of an Int32-typed input.
    let avg_of_column = Expr::AggregateFunction {
        fun: AggregateFunction::Avg,
        args: vec![col("a")],
        distinct: false,
        filter: None,
    };
    let column_plan = LogicalPlan::Projection(Projection::try_new(
        vec![avg_of_column],
        empty_with_type(DataType::Int32),
        None,
    )?);
    let optimized = rule.optimize(&column_plan, &mut config)?;
    let rendered = format!("{:?}", optimized);
    assert_eq!("Projection: AVG(a)\n EmptyRelation", &rendered);
    Ok(())
}

/// Builds a projection of `AVG` over a string literal and checks that
/// `Projection::try_new` rejects it with the expected plan error.
#[test]
fn agg_function_invalid_input() -> Result<()> {
    let avg_of_string = Expr::AggregateFunction {
        fun: AggregateFunction::Avg,
        args: vec![lit("1")],
        distinct: false,
        filter: None,
    };
    let outcome = Projection::try_new(vec![avg_of_string], empty(), None);
    assert!(outcome.is_err());

    let rendered = format!("{:?}", outcome.err().unwrap());
    assert_eq!(
        "Plan(\"The function Avg does not support inputs of type Utf8.\")",
        &rendered
    );
    Ok(())
}

#[test]
fn binary_op_date32_add_interval() -> Result<()> {
//CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
Expand Down
Loading

0 comments on commit 25a14ec

Please sign in to comment.