-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
move type coercion of agg and agg_udaf to logical phase #3768
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,8 +30,8 @@ use datafusion_expr::type_coercion::other::{ | |
}; | ||
use datafusion_expr::utils::from_plan; | ||
use datafusion_expr::{ | ||
function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, | ||
Expr, LogicalPlan, Operator, | ||
aggregate_function, function, is_false, is_not_false, is_not_true, is_not_unknown, | ||
is_true, is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, Operator, | ||
}; | ||
use datafusion_expr::{ExprSchemable, Signature}; | ||
use std::sync::Arc; | ||
|
@@ -407,6 +407,39 @@ impl ExprRewriter for TypeCoercionRewriter { | |
}; | ||
Ok(expr) | ||
} | ||
Expr::AggregateFunction { | ||
fun, | ||
args, | ||
distinct, | ||
filter, | ||
} => { | ||
let new_expr = coerce_agg_exprs_for_signature( | ||
&fun, | ||
&args, | ||
&self.schema, | ||
&aggregate_function::signature(&fun), | ||
)?; | ||
let expr = Expr::AggregateFunction { | ||
fun, | ||
args: new_expr, | ||
distinct, | ||
filter, | ||
}; | ||
Ok(expr) | ||
} | ||
Expr::AggregateUDF { fun, args, filter } => { | ||
let new_expr = coerce_arguments_for_signature( | ||
args.as_slice(), | ||
&self.schema, | ||
&fun.signature, | ||
)?; | ||
let expr = Expr::AggregateUDF { | ||
fun, | ||
args: new_expr, | ||
filter, | ||
}; | ||
Ok(expr) | ||
} | ||
expr => Ok(expr), | ||
} | ||
} | ||
|
@@ -448,6 +481,33 @@ fn coerce_arguments_for_signature( | |
.collect::<Result<Vec<_>>>() | ||
} | ||
|
||
/// Returns the coerced exprs for each `input_exprs`. | ||
/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the | ||
/// data type of `input_exprs` need to be coerced. | ||
fn coerce_agg_exprs_for_signature( | ||
agg_fun: &AggregateFunction, | ||
input_exprs: &[Expr], | ||
schema: &DFSchema, | ||
signature: &Signature, | ||
) -> Result<Vec<Expr>> { | ||
if input_exprs.is_empty() { | ||
return Ok(vec![]); | ||
} | ||
let current_types = input_exprs | ||
.iter() | ||
.map(|e| e.get_type(schema)) | ||
.collect::<Result<Vec<_>>>()?; | ||
|
||
let coerced_types = | ||
type_coercion::aggregates::coerce_types(agg_fun, ¤t_types, signature)?; | ||
|
||
input_exprs | ||
.iter() | ||
.enumerate() | ||
.map(|(i, expr)| expr.clone().cast_to(&coerced_types[i], schema)) | ||
.collect::<Result<Vec<_>>>() | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter}; | ||
|
@@ -456,14 +516,17 @@ mod test { | |
use datafusion_common::{DFField, DFSchema, Result, ScalarValue}; | ||
use datafusion_expr::expr_rewriter::ExprRewritable; | ||
use datafusion_expr::{ | ||
cast, col, concat, concat_ws, is_true, BuiltinScalarFunction, ColumnarValue, | ||
cast, col, concat, concat_ws, create_udaf, is_true, | ||
AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF, | ||
BuiltinScalarFunction, ColumnarValue, StateTypeFunction, | ||
}; | ||
use datafusion_expr::{ | ||
lit, | ||
logical_plan::{EmptyRelation, Projection}, | ||
Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, | ||
Signature, Volatility, | ||
}; | ||
use datafusion_physical_expr::expressions::AvgAccumulator; | ||
use std::sync::Arc; | ||
|
||
#[test] | ||
|
@@ -596,6 +659,123 @@ mod test { | |
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn agg_udaf() -> Result<()> { | ||
let empty = empty(); | ||
let my_avg = create_udaf( | ||
"MY_AVG", | ||
DataType::Float64, | ||
Arc::new(DataType::Float64), | ||
Volatility::Immutable, | ||
Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))), | ||
Arc::new(vec![DataType::UInt64, DataType::Float64]), | ||
); | ||
let udaf = Expr::AggregateUDF { | ||
fun: Arc::new(my_avg), | ||
args: vec![lit(10i64)], | ||
filter: None, | ||
}; | ||
let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty, None)?); | ||
let rule = TypeCoercion::new(); | ||
let mut config = OptimizerConfig::default(); | ||
let plan = rule.optimize(&plan, &mut config)?; | ||
assert_eq!( | ||
"Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation", | ||
&format!("{:?}", plan) | ||
); | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn agg_udaf_invalid_input() -> Result<()> { | ||
let empty = empty(); | ||
let return_type: ReturnTypeFunction = | ||
Arc::new(move |_| Ok(Arc::new(DataType::Float64))); | ||
let state_type: StateTypeFunction = | ||
Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64]))); | ||
let accumulator: AccumulatorFunctionImplementation = | ||
Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))); | ||
let my_avg = AggregateUDF::new( | ||
"MY_AVG", | ||
&Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), | ||
&return_type, | ||
&accumulator, | ||
&state_type, | ||
); | ||
let udaf = Expr::AggregateUDF { | ||
fun: Arc::new(my_avg), | ||
args: vec![lit("10")], | ||
filter: None, | ||
}; | ||
let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty, None)?); | ||
let rule = TypeCoercion::new(); | ||
let mut config = OptimizerConfig::default(); | ||
let plan = rule.optimize(&plan, &mut config); | ||
assert!(plan.is_err()); | ||
assert_eq!( | ||
"Plan(\"Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed.\")", | ||
&format!("{:?}", plan.err().unwrap()) | ||
); | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn agg_function_case() -> Result<()> { | ||
let empty = empty(); | ||
let fun: AggregateFunction = AggregateFunction::Avg; | ||
let agg_expr = Expr::AggregateFunction { | ||
fun, | ||
args: vec![lit(12i64)], | ||
distinct: false, | ||
filter: None, | ||
}; | ||
let plan = | ||
LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty, None)?); | ||
let rule = TypeCoercion::new(); | ||
let mut config = OptimizerConfig::default(); | ||
let plan = rule.optimize(&plan, &mut config)?; | ||
assert_eq!( | ||
"Projection: AVG(Int64(12))\n EmptyRelation", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see any coercion happening here. Maybe I am missing something There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, you didn't missing anything. You can take a look |
||
&format!("{:?}", plan) | ||
); | ||
|
||
let empty = empty_with_type(DataType::Int32); | ||
let fun: AggregateFunction = AggregateFunction::Avg; | ||
let agg_expr = Expr::AggregateFunction { | ||
fun, | ||
args: vec![col("a")], | ||
distinct: false, | ||
filter: None, | ||
}; | ||
let plan = | ||
LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty, None)?); | ||
let plan = rule.optimize(&plan, &mut config)?; | ||
assert_eq!( | ||
"Projection: AVG(a)\n EmptyRelation", | ||
&format!("{:?}", plan) | ||
); | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn agg_function_invalid_input() -> Result<()> { | ||
let empty = empty(); | ||
let fun: AggregateFunction = AggregateFunction::Avg; | ||
let agg_expr = Expr::AggregateFunction { | ||
fun, | ||
args: vec![lit("1")], | ||
distinct: false, | ||
filter: None, | ||
}; | ||
let expr = Projection::try_new(vec![agg_expr], empty, None); | ||
assert!(expr.is_err()); | ||
assert_eq!( | ||
"Plan(\"The function Avg does not support inputs of type Utf8.\")", | ||
&format!("{:?}", expr.err().unwrap()) | ||
); | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn binary_op_date32_add_interval() -> Result<()> { | ||
//CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍