Skip to content

Commit

Permalink
support type coercion for ScalarFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
liukun4515 committed Oct 8, 2022
1 parent 7c5c2e5 commit 982b46c
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 42 deletions.
66 changes: 51 additions & 15 deletions datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ use datafusion_expr::type_coercion::other::{
};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, Expr,
LogicalPlan, Operator,
function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
Expr, LogicalPlan, Operator,
};
use datafusion_expr::{ExprSchemable, Signature};
use std::sync::Arc;
Expand Down Expand Up @@ -310,18 +310,6 @@ impl ExprRewriter for TypeCoercionRewriter {
};
Ok(expr)
}
Expr::ScalarUDF { fun, args } => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
)?;
let expr = Expr::ScalarUDF {
fun,
args: new_expr,
};
Ok(expr)
}
Expr::InList {
expr,
list,
Expand Down Expand Up @@ -401,6 +389,30 @@ impl ExprRewriter for TypeCoercionRewriter {
}
}
}
Expr::ScalarUDF { fun, args } => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
)?;
let expr = Expr::ScalarUDF {
fun,
args: new_expr,
};
Ok(expr)
}
Expr::ScalarFunction { fun, args } => {
let nex_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&function::signature(&fun),
)?;
let expr = Expr::ScalarFunction {
fun,
args: nex_expr,
};
Ok(expr)
}
expr => Ok(expr),
}
}
Expand Down Expand Up @@ -449,7 +461,7 @@ mod test {
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
use datafusion_expr::expr_rewriter::ExprRewritable;
use datafusion_expr::{cast, col, is_true, ColumnarValue};
use datafusion_expr::{cast, col, is_true, BuiltinScalarFunction, ColumnarValue};
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
Expand Down Expand Up @@ -564,6 +576,30 @@ mod test {
Ok(())
}

#[test]
fn scalar_function() -> Result<()> {
    // Build `abs(10i64)` as a built-in scalar function call over an Int64 literal.
    let abs_call = Expr::ScalarFunction {
        fun: BuiltinScalarFunction::Abs,
        args: vec![lit(10i64)],
    };
    let projection = Projection::try_new(vec![abs_call], empty(), None)?;
    let input = LogicalPlan::Projection(projection);

    // Run the type-coercion rule over the projection.
    let mut optimizer_config = OptimizerConfig::default();
    let optimized = TypeCoercion::new().optimize(&input, &mut optimizer_config)?;

    // The Int64 literal is expected to be wrapped in a cast to Float64 for `abs`.
    let expected = "Projection: abs(CAST(Int64(10) AS Float64))\n EmptyRelation";
    assert_eq!(expected, &format!("{:?}", optimized));
    Ok(())
}

#[test]
fn binary_op_date32_add_interval() -> Result<()> {
//CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
Expand Down
69 changes: 42 additions & 27 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ use crate::execution_props::ExecutionProps;
use crate::{
array_expressions, conditional_expressions, datetime_expressions,
expressions::{cast_column, nullif_func, DEFAULT_DATAFUSION_CAST_OPTIONS},
math_expressions, string_expressions, struct_expressions,
type_coercion::coerce,
PhysicalExpr, ScalarFunctionExpr,
math_expressions, string_expressions, struct_expressions, PhysicalExpr,
ScalarFunctionExpr,
};
use arrow::{
array::ArrayRef,
Expand All @@ -58,23 +57,20 @@ pub fn create_physical_expr(
input_schema: &Schema,
execution_props: &ExecutionProps,
) -> Result<Arc<dyn PhysicalExpr>> {
let coerced_phy_exprs =
coerce(input_phy_exprs, input_schema, &function::signature(fun))?;

let coerced_expr_types = coerced_phy_exprs
let input_expr_types = input_phy_exprs
.iter()
.map(|e| e.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;

let data_type = function::return_type(fun, &coerced_expr_types)?;
let data_type = function::return_type(fun, &input_expr_types)?;

let fun_expr: ScalarFunctionImplementation = match fun {
// These functions need args and input schema to pick an implementation
// Unlike the string functions, which actually figure out the function to use with each array,
// here we return either a cast fn or string timestamp translation based on the expression data type
// so we don't have to pay a per-array/batch cost.
BuiltinScalarFunction::ToTimestamp => {
Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
Arc::new(match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
Expand All @@ -89,12 +85,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function to_timestamp",
other,
)))
)));
}
})
}
BuiltinScalarFunction::ToTimestampMillis => {
Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
Arc::new(match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
Expand All @@ -109,12 +105,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function to_timestamp_millis",
other,
)))
)));
}
})
}
BuiltinScalarFunction::ToTimestampMicros => {
Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
Arc::new(match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
Expand All @@ -129,12 +125,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function to_timestamp_micros",
other,
)))
)));
}
})
}
BuiltinScalarFunction::ToTimestampSeconds => Arc::new({
match coerced_phy_exprs[0].data_type(input_schema) {
match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
Expand All @@ -149,12 +145,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function to_timestamp_seconds",
other,
)))
)));
}
}
}),
BuiltinScalarFunction::FromUnixtime => Arc::new({
match coerced_phy_exprs[0].data_type(input_schema) {
match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) => |col_values: &[ColumnarValue]| {
cast_column(
&col_values[0],
Expand All @@ -166,12 +162,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function from_unixtime",
other,
)))
)));
}
}
}),
BuiltinScalarFunction::ArrowTypeof => {
let input_data_type = coerced_phy_exprs[0].data_type(input_schema)?;
let input_data_type = input_phy_exprs[0].data_type(input_schema)?;
Arc::new(move |_| {
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!(
"{}",
Expand All @@ -186,7 +182,7 @@ pub fn create_physical_expr(
Ok(Arc::new(ScalarFunctionExpr::new(
&format!("{}", fun),
fun_expr,
coerced_phy_exprs,
input_phy_exprs.to_vec(),
&data_type,
)))
}
Expand Down Expand Up @@ -727,7 +723,7 @@ pub fn create_physical_fun(
return Err(DataFusionError::Internal(format!(
"create_physical_fun: Unsupported scalar function {:?}",
fun
)))
)));
}
})
}
Expand All @@ -737,6 +733,7 @@ mod tests {
use super::*;
use crate::expressions::{col, lit};
use crate::from_slice::FromSlice;
use crate::type_coercion::coerce;
use arrow::{
array::{
Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float32Array,
Expand Down Expand Up @@ -764,7 +761,7 @@ mod tests {
let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from_slice(&[1]))];

let expr =
create_physical_expr(&BuiltinScalarFunction::$FUNC, $ARGS, &schema, &execution_props)?;
create_physical_expr_with_type_coercion(&BuiltinScalarFunction::$FUNC, $ARGS, &schema, &execution_props)?;

// type is correct
assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TYPE);
Expand Down Expand Up @@ -2683,7 +2680,12 @@ mod tests {
];

for fun in funs.iter() {
let expr = create_physical_expr(fun, &[], &schema, &execution_props);
let expr = create_physical_expr_with_type_coercion(
fun,
&[],
&schema,
&execution_props,
);

match expr {
Ok(..) => {
Expand Down Expand Up @@ -2720,7 +2722,7 @@ mod tests {
let funs = [BuiltinScalarFunction::Now, BuiltinScalarFunction::Random];

for fun in funs.iter() {
create_physical_expr(fun, &[], &schema, &execution_props)?;
create_physical_expr_with_type_coercion(fun, &[], &schema, &execution_props)?;
}
Ok(())
}
Expand All @@ -2739,7 +2741,7 @@ mod tests {
let columns: Vec<ArrayRef> = vec![value1, value2];
let execution_props = ExecutionProps::new();

let expr = create_physical_expr(
let expr = create_physical_expr_with_type_coercion(
&BuiltinScalarFunction::MakeArray,
&[col("a", &schema)?, col("b", &schema)?],
&schema,
Expand Down Expand Up @@ -2805,7 +2807,7 @@ mod tests {
let col_value: ArrayRef = Arc::new(StringArray::from_slice(&["aaa-555"]));
let pattern = lit(r".*-(\d*)");
let columns: Vec<ArrayRef> = vec![col_value];
let expr = create_physical_expr(
let expr = create_physical_expr_with_type_coercion(
&BuiltinScalarFunction::RegexpMatch,
&[col("a", &schema)?, pattern],
&schema,
Expand Down Expand Up @@ -2844,7 +2846,7 @@ mod tests {
let col_value = lit("aaa-555");
let pattern = lit(r".*-(\d*)");
let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from_slice(&[1]))];
let expr = create_physical_expr(
let expr = create_physical_expr_with_type_coercion(
&BuiltinScalarFunction::RegexpMatch,
&[col_value, pattern],
&schema,
Expand Down Expand Up @@ -2872,4 +2874,17 @@ mod tests {

Ok(())
}

/// Test helper: coerces `input_phy_exprs` to the signature of `fun`, then
/// builds the physical expression from the coerced arguments.
///
/// Type coercion is performed in the logical planning phase, so tests that
/// construct physical expressions directly must coerce the arguments first.
///
/// # Errors
///
/// Returns an error if coercing the arguments fails or if
/// `create_physical_expr` rejects the coerced arguments.
fn create_physical_expr_with_type_coercion(
    fun: &BuiltinScalarFunction,
    input_phy_exprs: &[Arc<dyn PhysicalExpr>],
    input_schema: &Schema,
    execution_props: &ExecutionProps,
) -> Result<Arc<dyn PhysicalExpr>> {
    // Propagate coercion failures as `Err` instead of panicking, so callers
    // that match on the returned `Result` observe them.
    let type_coerced_phy_exprs =
        coerce(input_phy_exprs, input_schema, &function::signature(fun))?;
    create_physical_expr(fun, &type_coerced_phy_exprs, input_schema, execution_props)
}
}

0 comments on commit 982b46c

Please sign in to comment.