Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move type coercion of agg and agg_udaf to logical phase #3768

Merged
merged 4 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions datafusion/core/src/physical_plan/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ use arrow::{
datatypes::{DataType, Schema},
};

use super::{
expressions::format_state_name, type_coercion::coerce, Accumulator, AggregateExpr,
};
use super::{expressions::format_state_name, Accumulator, AggregateExpr};
use crate::error::Result;
use crate::physical_plan::PhysicalExpr;
pub use datafusion_expr::AggregateUDF;
Expand All @@ -43,18 +41,15 @@ pub fn create_aggregate_expr(
input_schema: &Schema,
name: impl Into<String>,
) -> Result<Arc<dyn AggregateExpr>> {
// coerce
let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &fun.signature)?;

let coerced_exprs_types = coerced_phy_exprs
let input_exprs_types = input_phy_exprs
.iter()
.map(|arg| arg.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;

Ok(Arc::new(AggregateFunctionExpr {
fun: fun.clone(),
args: coerced_phy_exprs.clone(),
data_type: (fun.return_type)(&coerced_exprs_types)?.as_ref().clone(),
args: input_phy_exprs.to_vec(),
data_type: (fun.return_type)(&input_exprs_types)?.as_ref().clone(),
name: name.into(),
}))
}
Expand Down
186 changes: 183 additions & 3 deletions datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use datafusion_expr::type_coercion::other::{
};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
Expr, LogicalPlan, Operator,
aggregate_function, function, is_false, is_not_false, is_not_true, is_not_unknown,
is_true, is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, Operator,
};
use datafusion_expr::{ExprSchemable, Signature};
use std::sync::Arc;
Expand Down Expand Up @@ -407,6 +407,39 @@ impl ExprRewriter for TypeCoercionRewriter {
};
Ok(expr)
}
Expr::AggregateFunction {
fun,
args,
distinct,
filter,
} => {
let new_expr = coerce_agg_exprs_for_signature(
&fun,
&args,
&self.schema,
&aggregate_function::signature(&fun),
)?;
let expr = Expr::AggregateFunction {
fun,
args: new_expr,
distinct,
filter,
};
Ok(expr)
}
Expr::AggregateUDF { fun, args, filter } => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
)?;
let expr = Expr::AggregateUDF {
fun,
args: new_expr,
filter,
};
Ok(expr)
}
expr => Ok(expr),
}
}
Expand Down Expand Up @@ -448,6 +481,33 @@ fn coerce_arguments_for_signature(
.collect::<Result<Vec<_>>>()
}

/// Coerces the arguments of a built-in aggregate function.
///
/// The current type of every expression in `input_exprs` is resolved against
/// `schema`, the target types are obtained from
/// `type_coercion::aggregates::coerce_types`, and each expression is then
/// passed through `cast_to` with the target type at the same position.
///
/// An empty `input_exprs` yields an empty vector without consulting
/// `signature`. The first error raised while resolving types, computing the
/// coerced types, or casting is returned to the caller.
fn coerce_agg_exprs_for_signature(
    agg_fun: &AggregateFunction,
    input_exprs: &[Expr],
    schema: &DFSchema,
    signature: &Signature,
) -> Result<Vec<Expr>> {
    // Nothing to coerce for argument-less calls.
    if input_exprs.is_empty() {
        return Ok(Vec::new());
    }

    // Resolve the type each argument currently has.
    let mut arg_types = Vec::with_capacity(input_exprs.len());
    for expr in input_exprs {
        arg_types.push(expr.get_type(schema)?);
    }

    let target_types =
        type_coercion::aggregates::coerce_types(agg_fun, &arg_types, signature)?;

    // Pair every argument with the target type at the same index.
    let mut coerced = Vec::with_capacity(input_exprs.len());
    for (idx, expr) in input_exprs.iter().enumerate() {
        coerced.push(expr.clone().cast_to(&target_types[idx], schema)?);
    }
    Ok(coerced)
}

#[cfg(test)]
mod test {
use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter};
Expand All @@ -456,14 +516,17 @@ mod test {
use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
use datafusion_expr::expr_rewriter::ExprRewritable;
use datafusion_expr::{
cast, col, concat, concat_ws, is_true, BuiltinScalarFunction, ColumnarValue,
cast, col, concat, concat_ws, create_udaf, is_true,
AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF,
BuiltinScalarFunction, ColumnarValue, StateTypeFunction,
};
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF,
Signature, Volatility,
};
use datafusion_physical_expr::expressions::AvgAccumulator;
use std::sync::Arc;

#[test]
Expand Down Expand Up @@ -596,6 +659,123 @@ mod test {
Ok(())
}

#[test]
fn agg_udaf() -> Result<()> {
let empty = empty();
let my_avg = create_udaf(
"MY_AVG",
DataType::Float64,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);
let udaf = Expr::AggregateUDF {
fun: Arc::new(my_avg),
args: vec![lit(10i64)],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

filter: None,
};
let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty, None)?);
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
}

/// A UDAF whose signature is `Uniform(1, [Float64])` is called with a Utf8
/// literal; running the `TypeCoercion` rule over the plan must fail with a
/// plan error naming the offending input types and the signature.
#[test]
fn agg_udaf_invalid_input() -> Result<()> {
    let input = empty();

    // Building blocks of the user-defined aggregate.
    let signature =
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable);
    let return_type: ReturnTypeFunction =
        Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
    let accumulator: AccumulatorFunctionImplementation =
        Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?)));
    let state_type: StateTypeFunction =
        Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
    let my_avg = AggregateUDF::new(
        "MY_AVG",
        &signature,
        &return_type,
        &accumulator,
        &state_type,
    );

    // Call the UDAF with a string literal argument.
    let call = Expr::AggregateUDF {
        fun: Arc::new(my_avg),
        args: vec![lit("10")],
        filter: None,
    };
    let plan = LogicalPlan::Projection(Projection::try_new(vec![call], input, None)?);

    let rule = TypeCoercion::new();
    let mut config = OptimizerConfig::default();
    let outcome = rule.optimize(&plan, &mut config);

    assert!(outcome.is_err());
    let message = format!("{:?}", outcome.err().unwrap());
    assert_eq!(
        "Plan(\"Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed.\")",
        &message
    );
    Ok(())
}

#[test]
fn agg_function_case() -> Result<()> {
let empty = empty();
let fun: AggregateFunction = AggregateFunction::Avg;
let agg_expr = Expr::AggregateFunction {
fun,
args: vec![lit(12i64)],
distinct: false,
filter: None,
};
let plan =
LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty, None)?);
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: AVG(Int64(12))\n EmptyRelation",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any coercion happening here. Maybe I am missing something

Copy link
Contributor Author

@liukun4515 liukun4515 Oct 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you didn't miss anything.

You can take a look at the type_coercion::aggregates::coerce_types function, which just checks the input data type and doesn't do any coercion for the function.

&format!("{:?}", plan)
);

let empty = empty_with_type(DataType::Int32);
let fun: AggregateFunction = AggregateFunction::Avg;
let agg_expr = Expr::AggregateFunction {
fun,
args: vec![col("a")],
distinct: false,
filter: None,
};
let plan =
LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty, None)?);
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: AVG(a)\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
}

/// `AVG` over a Utf8 literal is rejected: `Projection::try_new` itself
/// returns the plan error, so no optimizer rule is invoked in this test.
#[test]
fn agg_function_invalid_input() -> Result<()> {
    let input = empty();
    let avg_of_utf8 = Expr::AggregateFunction {
        fun: AggregateFunction::Avg,
        args: vec![lit("1")],
        distinct: false,
        filter: None,
    };

    let projection = Projection::try_new(vec![avg_of_utf8], input, None);

    assert!(projection.is_err());
    let message = format!("{:?}", projection.err().unwrap());
    assert_eq!(
        "Plan(\"The function Avg does not support inputs of type Utf8.\")",
        &message
    );
    Ok(())
}

#[test]
fn binary_op_date32_add_interval() -> Result<()> {
//CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
Expand Down
Loading