diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index e3d5820327ec2..83be3d4692a45 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -39,7 +39,6 @@ use crate::execution::physical_plan::common; use crate::execution::physical_plan::csv::CsvReadOptions; use crate::execution::physical_plan::merge::MergeExec; use crate::execution::physical_plan::planner::DefaultPhysicalPlanner; -use crate::execution::physical_plan::scalar_functions; use crate::execution::physical_plan::udf::ScalarFunction; use crate::execution::physical_plan::ExecutionPlan; use crate::execution::physical_plan::PhysicalPlanner; @@ -105,16 +104,13 @@ impl ExecutionContext { /// Create a new execution context using the provided configuration pub fn with_config(config: ExecutionConfig) -> Self { - let mut ctx = Self { + let ctx = Self { state: Arc::new(Mutex::new(ExecutionContextState { datasources: HashMap::new(), scalar_functions: HashMap::new(), config, })), }; - for udf in scalar_functions() { - ctx.register_udf(udf); - } ctx } diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index b8254f3d79061..35381066653c0 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -1361,7 +1361,7 @@ pub struct CastExpr { } /// Determine if a DataType is numeric or not -fn is_numeric(dt: &DataType) -> bool { +pub fn is_numeric(dt: &DataType) -> bool { match dt { DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => true, DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => true, diff --git a/rust/datafusion/src/execution/physical_plan/functions.rs b/rust/datafusion/src/execution/physical_plan/functions.rs new file mode 100644 index 0000000000000..8fcbe9512e43d --- /dev/null +++ 
b/rust/datafusion/src/execution/physical_plan/functions.rs @@ -0,0 +1,261 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Declaration of built-in (scalar) functions. +//! This module contains built-in functions' enumeration and metadata. +//! +//! Generally, a function has: +//! * a signature +//! * a return type, that is a function of the incoming argument's types +//! * the computation, that must accept each valid signature +//! +//! * Signature: see `Signature` +//! * Return type: a function `(arg_types) -> return_type`. E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64. +//! +//! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed +//! to a function that supports f64, it is coerced to f64. + +use super::{ + type_coercion::{coerce, data_types}, + PhysicalExpr, +}; +use crate::error::{ExecutionError, Result}; +use crate::execution::physical_plan::math_expressions; +use crate::execution::physical_plan::udf; +use arrow::{ + compute::kernels::length::length, + datatypes::{DataType, Schema}, +}; +use std::{fmt, str::FromStr, sync::Arc}; +use udf::ScalarUdf; + +/// A function's signature, which defines the function's supported argument types. 
+#[derive(Debug)] +pub enum Signature { + /// arbitrary number of arguments of a common type out of a list of valid types + // A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` + Variadic(Vec), + /// arbitrary number of arguments of an arbitrary but equal type + // A function such as `array` is `VariadicEqual` + // The first argument decides the type used for coercion + VariadicEqual, + /// fixed number of arguments of an arbitrary but equal type out of a list of valid types + // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` + // A function of two arguments of f64 or f32 is `Uniform(2, vec![DataType::Float32, DataType::Float64])` + Uniform(usize, Vec), +} + +/// Enum of all built-in scalar functions +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ScalarFunction { + /// sqrt + Sqrt, + /// sin + Sin, + /// cos + Cos, + /// tan + Tan, + /// asin + Asin, + /// acos + Acos, + /// atan + Atan, + /// exp + Exp, + /// log, also known as ln + Log, + /// log2 + Log2, + /// log10 + Log10, + /// floor + Floor, + /// ceil + Ceil, + /// round + Round, + /// trunc + Trunc, + /// abs + Abs, + /// signum + Signum, + /// length + Length, +} + +impl fmt::Display for ScalarFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // lowercase of the debug.
+ write!(f, "{}", format!("{:?}", self).to_lowercase()) + } +} + +impl FromStr for ScalarFunction { + type Err = ExecutionError; + fn from_str(name: &str) -> Result { + Ok(match name { + "sqrt" => ScalarFunction::Sqrt, + "sin" => ScalarFunction::Sin, + "cos" => ScalarFunction::Cos, + "tan" => ScalarFunction::Tan, + "asin" => ScalarFunction::Asin, + "acos" => ScalarFunction::Acos, + "atan" => ScalarFunction::Atan, + "exp" => ScalarFunction::Exp, + "log" => ScalarFunction::Log, + "log2" => ScalarFunction::Log2, + "log10" => ScalarFunction::Log10, + "floor" => ScalarFunction::Floor, + "ceil" => ScalarFunction::Ceil, + "round" => ScalarFunction::Round, + "trunc" => ScalarFunction::Trunc, + "abs" => ScalarFunction::Abs, + "signum" => ScalarFunction::Signum, + "length" => ScalarFunction::Length, + _ => { + return Err(ExecutionError::General(format!( + "There is no built-in function named {}", + name + ))) + } + }) + } +} + +/// Returns the datatype of the scalar function +pub fn return_type(fun: &ScalarFunction, arg_types: &Vec) -> Result { + // Note that this function *must* return the same type that the respective physical expression returns + // or the execution panics. + + // verify that this is a valid set of data types for this function + data_types(&arg_types, &signature(fun))?; + + // the return type after coercion. + // for now, this is type-independent, but there will be built-in functions whose return type + // depends on the incoming type. + match fun { + ScalarFunction::Length => Ok(DataType::UInt32), + _ => Ok(DataType::Float64), + } +} + +/// Create a physical (function) expression. +/// This function errors when `args`' can't be coerced to a valid argument type of the function.
+pub fn create_physical_expr( + fun: &ScalarFunction, + args: &Vec>, + input_schema: &Schema, +) -> Result> { + let fun_expr: ScalarUdf = Arc::new(match fun { + ScalarFunction::Sqrt => math_expressions::sqrt, + ScalarFunction::Sin => math_expressions::sin, + ScalarFunction::Cos => math_expressions::cos, + ScalarFunction::Tan => math_expressions::tan, + ScalarFunction::Asin => math_expressions::asin, + ScalarFunction::Acos => math_expressions::acos, + ScalarFunction::Atan => math_expressions::atan, + ScalarFunction::Exp => math_expressions::exp, + ScalarFunction::Log => math_expressions::ln, + ScalarFunction::Log2 => math_expressions::log2, + ScalarFunction::Log10 => math_expressions::log10, + ScalarFunction::Floor => math_expressions::floor, + ScalarFunction::Ceil => math_expressions::ceil, + ScalarFunction::Round => math_expressions::round, + ScalarFunction::Trunc => math_expressions::trunc, + ScalarFunction::Abs => math_expressions::abs, + ScalarFunction::Signum => math_expressions::signum, + ScalarFunction::Length => |args| Ok(Arc::new(length(args[0].as_ref())?)), + }); + // coerce + let args = coerce(args, input_schema, &signature(fun))?; + + let arg_types = args + .iter() + .map(|e| e.data_type(input_schema)) + .collect::>>()?; + + Ok(Arc::new(udf::ScalarFunctionExpr::new( + &format!("{}", fun), + fun_expr, + args, + &return_type(&fun, &arg_types)?, + ))) +} + +/// the signatures supported by the function `fun`. +fn signature(fun: &ScalarFunction) -> Signature { + // note: the physical expression must accept the type returned by this function or the execution panics. + + // for now, the list is small, as we do not have many built-in functions. 
+ match fun { + ScalarFunction::Length => Signature::Uniform(1, vec![DataType::Utf8]), + // math expressions expect 1 argument of type f64 + _ => Signature::Uniform(1, vec![DataType::Float64]), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + error::Result, execution::physical_plan::expressions::lit, + logicalplan::ScalarValue, + }; + use arrow::{ + array::{ArrayRef, Float64Array, Int32Array}, + datatypes::Field, + record_batch::RecordBatch, + }; + + fn generic_test_math(value: ScalarValue, expected: &str) -> Result<()> { + // any type works here: we evaluate against a literal of `value` + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; + + let arg = lit(value); + + let expr = create_physical_expr(&ScalarFunction::Exp, &vec![arg], &schema)?; + + // type is correct + assert_eq!(expr.data_type(&schema)?, DataType::Float64); + + // evaluate works + let result = + expr.evaluate(&RecordBatch::try_new(Arc::new(schema.clone()), columns)?)?; + + // downcast works + let result = result.as_any().downcast_ref::().unwrap(); + + // value is correct + assert_eq!(format!("{}", result.value(0)), expected); // = exp(1) + + Ok(()) + } + + #[test] + fn test_math_function() -> Result<()> { + let exp_f64 = "2.718281828459045"; + generic_test_math(ScalarValue::Int32(1i32), exp_f64)?; + generic_test_math(ScalarValue::UInt32(1u32), exp_f64)?; + generic_test_math(ScalarValue::Float64(1f64), exp_f64)?; + generic_test_math(ScalarValue::Float32(1f32), exp_f64)?; + Ok(()) + } +} diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs index ea40ac5efc613..7ea0795bb74c1 100644 --- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs @@ -17,102 +17,52 @@ //! 
Math expressions -use crate::error::ExecutionError; -use crate::execution::physical_plan::udf::ScalarFunction; +use crate::error::{ExecutionError, Result}; use arrow::array::{Array, ArrayRef, Float64Array, Float64Builder}; -use arrow::datatypes::DataType; use std::sync::Arc; macro_rules! math_unary_function { ($NAME:expr, $FUNC:ident) => { - ScalarFunction::new( - $NAME, - vec![DataType::Float64], - DataType::Float64, - Arc::new(|args: &[ArrayRef]| { - let n = &args[0].as_any().downcast_ref::(); - match n { - Some(array) => { - let mut builder = Float64Builder::new(array.len()); - for i in 0..array.len() { - if array.is_null(i) { - builder.append_null()?; - } else { - builder.append_value(array.value(i).$FUNC())?; - } + /// mathematical function + pub fn $FUNC(args: &[ArrayRef]) -> Result { + let n = &args[0].as_any().downcast_ref::(); + match n { + Some(array) => { + let mut builder = Float64Builder::new(array.len()); + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array.value(i).$FUNC())?; } - Ok(Arc::new(builder.finish())) } - _ => Err(ExecutionError::General(format!( - "Invalid data type for {}", - $NAME - ))), + Ok(Arc::new(builder.finish())) } - }), - ) + _ => Err(ExecutionError::General(format!( + "Invalid data type for {}", + $NAME + ))), + } + } }; } -/// vector of math scalar functions -pub fn scalar_functions() -> Vec { - vec![ - math_unary_function!("sqrt", sqrt), - math_unary_function!("sin", sin), - math_unary_function!("cos", cos), - math_unary_function!("tan", tan), - math_unary_function!("asin", asin), - math_unary_function!("acos", acos), - math_unary_function!("atan", atan), - math_unary_function!("floor", floor), - math_unary_function!("ceil", ceil), - math_unary_function!("round", round), - math_unary_function!("trunc", trunc), - math_unary_function!("abs", abs), - math_unary_function!("signum", signum), - math_unary_function!("exp", exp), - math_unary_function!("log", ln), - 
math_unary_function!("log2", log2), - math_unary_function!("log10", log10), - ] -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::error::Result; - use crate::{ - execution::context::ExecutionContext, - logicalplan::{col, sqrt, LogicalPlanBuilder}, - }; - use arrow::datatypes::{Field, Schema}; - - #[test] - fn cast_i8_input() -> Result<()> { - let schema = Schema::new(vec![Field::new("c0", DataType::Int8, true)]); - let plan = LogicalPlanBuilder::scan("", "", &schema, None)? - .project(vec![sqrt(col("c0"))])? - .build()?; - let ctx = ExecutionContext::new(); - let plan = ctx.optimize(&plan)?; - let expected = "Projection: sqrt(CAST(#c0 AS Float64))\ - \n TableScan: projection=Some([0])"; - assert_eq!(format!("{:?}", plan), expected); - Ok(()) - } - - #[test] - fn no_cast_f64_input() -> Result<()> { - let schema = Schema::new(vec![Field::new("c0", DataType::Float64, true)]); - let plan = LogicalPlanBuilder::scan("", "", &schema, None)? - .project(vec![sqrt(col("c0"))])? - .build()?; - let ctx = ExecutionContext::new(); - let plan = ctx.optimize(&plan)?; - let expected = "Projection: sqrt(#c0)\ - \n TableScan: projection=Some([0])"; - assert_eq!(format!("{:?}", plan), expected); - Ok(()) - } -} +math_unary_function!("sqrt", sqrt); +math_unary_function!("sin", sin); +math_unary_function!("cos", cos); +math_unary_function!("tan", tan); +math_unary_function!("asin", asin); +math_unary_function!("acos", acos); +math_unary_function!("atan", atan); +math_unary_function!("floor", floor); +math_unary_function!("ceil", ceil); +math_unary_function!("round", round); +math_unary_function!("trunc", trunc); +math_unary_function!("abs", abs); +math_unary_function!("signum", signum); +math_unary_function!("exp", exp); +math_unary_function!("log", ln); +math_unary_function!("log2", log2); +math_unary_function!("log10", log10); diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index 
780cdba395dbf..7a6ed865e33fb 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -27,11 +27,7 @@ use crate::execution::context::ExecutionContextState; use crate::logicalplan::{LogicalPlan, ScalarValue}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Schema, SchemaRef}; -use arrow::{ - compute::kernels::length::length, - record_batch::{RecordBatch, RecordBatchReader}, -}; -use udf::ScalarFunction; +use arrow::record_batch::{RecordBatch, RecordBatchReader}; /// Physical query planner that converts a `LogicalPlan` to an /// `ExecutionPlan` suitable for execution. @@ -134,23 +130,12 @@ pub trait Accumulator: Debug { fn get_value(&self) -> Result>; } -/// Vector of scalar functions declared in this module -pub fn scalar_functions() -> Vec { - let mut udfs = vec![ScalarFunction::new( - "length", - vec![DataType::Utf8], - DataType::UInt32, - Arc::new(|args: &[ArrayRef]| Ok(Arc::new(length(args[0].as_ref())?))), - )]; - udfs.append(&mut math_expressions::scalar_functions()); - udfs -} - pub mod common; pub mod csv; pub mod explain; pub mod expressions; pub mod filter; +pub mod functions; pub mod hash_aggregate; pub mod limit; pub mod math_expressions; @@ -160,4 +145,5 @@ pub mod parquet; pub mod planner; pub mod projection; pub mod sort; +pub mod type_coercion; pub mod udf; diff --git a/rust/datafusion/src/execution/physical_plan/planner.rs b/rust/datafusion/src/execution/physical_plan/planner.rs index 18b86a0fe5e3e..1c24cd3f30445 100644 --- a/rust/datafusion/src/execution/physical_plan/planner.rs +++ b/rust/datafusion/src/execution/physical_plan/planner.rs @@ -19,7 +19,7 @@ use std::sync::Arc; -use super::expressions::binary; +use super::{expressions::binary, functions}; use crate::error::{ExecutionError, Result}; use crate::execution::context::ExecutionContextState; use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions}; @@ -356,7 +356,16 @@ impl DefaultPhysicalPlanner { 
input_schema, data_type.clone(), ), - Expr::ScalarFunction { + Expr::ScalarFunction { fun, args } => { + let physical_args = args + .iter() + .map(|e| { + self.create_physical_expr(e, input_schema, ctx_state.clone()) + }) + .collect::>>()?; + functions::create_physical_expr(fun, &physical_args, input_schema) + } + Expr::ScalarUDF { name, args, return_type, @@ -372,7 +381,7 @@ impl DefaultPhysicalPlanner { } Ok(Arc::new(ScalarFunctionExpr::new( name, - Box::new(f.fun.clone()), + f.fun.clone(), physical_args, return_type, ))) diff --git a/rust/datafusion/src/execution/physical_plan/type_coercion.rs b/rust/datafusion/src/execution/physical_plan/type_coercion.rs new file mode 100644 index 0000000000000..dd08cd2cb4b70 --- /dev/null +++ b/rust/datafusion/src/execution/physical_plan/type_coercion.rs @@ -0,0 +1,268 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Type coercion rules for functions with multiple valid signatures + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Schema}; + +use super::{functions::Signature, PhysicalExpr}; +use crate::error::{ExecutionError, Result}; +use crate::execution::physical_plan::expressions::{cast, numerical_coercion}; +use crate::logicalplan::Operator; + +/// Returns expressions constructed by casting `expressions` to types compatible with `signatures`. +pub fn coerce( + expressions: &Vec>, + schema: &Schema, + signature: &Signature, +) -> Result>> { + let current_types = expressions + .iter() + .map(|e| e.data_type(schema)) + .collect::>>()?; + + let new_types = data_types(&current_types, signature)?; + + expressions + .iter() + .enumerate() + .map(|(i, expr)| cast(expr.clone(), &schema, new_types[i].clone())) + .collect::>>() +} + +/// returns the data types that each argument must be casted to match the `signature`. +pub fn data_types( + current_types: &Vec, + signature: &Signature, +) -> Result> { + let valid_types = match signature { + Signature::Variadic(valid_types) => valid_types + .iter() + .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) + .collect(), + Signature::Uniform(number, valid_types) => valid_types + .iter() + .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) + .collect(), + Signature::VariadicEqual => { + // one entry with the same len as current_types, whose type is `current_types[0]`. + vec![current_types + .iter() + .map(|_| current_types[0].clone()) + .collect()] + } + }; + + if valid_types.contains(current_types) { + return Ok(current_types.clone()); + } + + for valid_types in valid_types { + if let Some(types) = maybe_data_types(&valid_types, &current_types) { + return Ok(types); + } + } + + // none possible -> Error + Err(ExecutionError::General(format!( + "Coercion from {:?} to the signature {:?} failed.", + current_types, signature + ))) +} + +/// Try to coerce current_types into valid_types.
+fn maybe_data_types( + valid_types: &Vec, + current_types: &Vec, +) -> Option> { + if valid_types.len() != current_types.len() { + return None; + } + + let mut new_type = Vec::with_capacity(valid_types.len()); + for (i, valid_type) in valid_types.iter().enumerate() { + let current_type = &current_types[i]; + + if current_type == valid_type { + new_type.push(current_type.clone()) + } else { + // attempt to coerce using numerical coercion + // todo: also try string coercion. + if let Ok(cast_to_type) = numerical_coercion( + &current_type, + // assume that the function behaves like plus + // plus is not special here; this function is just trying its best... + &Operator::Plus, + valid_type, + ) { + new_type.push(cast_to_type) + } else { + // not possible + return None; + } + } + } + Some(new_type) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::execution::physical_plan::expressions::col; + use arrow::datatypes::{DataType, Field, Schema}; + + #[test] + fn test_maybe_data_types() -> Result<()> { + // this vec contains: arg1, arg2, expected result + let cases = vec![ + // 2 entries, same values + ( + vec![DataType::UInt8, DataType::UInt16], + vec![DataType::UInt8, DataType::UInt16], + Some(vec![DataType::UInt8, DataType::UInt16]), + ), + // 2 entries, can coerce values + ( + vec![DataType::UInt16, DataType::UInt16], + vec![DataType::UInt8, DataType::UInt16], + Some(vec![DataType::UInt16, DataType::UInt16]), + ), + // 0 entries, all good + (vec![], vec![], Some(vec![])), + // 2 entries, can't coerce + ( + vec![DataType::Boolean, DataType::UInt16], + vec![DataType::UInt8, DataType::UInt16], + None, + ), + // u32 -> u16 is possible + ( + vec![DataType::Boolean, DataType::UInt32], + vec![DataType::Boolean, DataType::UInt16], + Some(vec![DataType::Boolean, DataType::UInt32]), + ), + ]; + + for case in cases { + assert_eq!(maybe_data_types(&case.0, &case.1), case.2) + } + Ok(()) + } + + #[test] + fn test_coerce() -> Result<()> { + // create a schema + let schema = |t: Vec| {
+ Schema::new( + t.iter() + .enumerate() + .map(|(i, t)| Field::new(&*format!("c{}", i), t.clone(), true)) + .collect(), + ) + }; + + // create a vector of expressions + let expressions = |t: Vec, schema| -> Result> { + t.iter() + .enumerate() + .map(|(i, t)| cast(col(&format!("c{}", i)), &schema, t.clone())) + .collect::>>() + }; + + // create a case: input + expected result + let case = + |observed: Vec, valid, expected: Vec| -> Result<_> { + let schema = schema(observed.clone()); + let expr = expressions(observed, schema.clone())?; + let expected = expressions(expected, schema.clone())?; + Ok((expr.clone(), schema, valid, expected)) + }; + + let cases = vec![ + // u16 -> u32 + case( + vec![DataType::UInt16], + Signature::Uniform(1, vec![DataType::UInt32]), + vec![DataType::UInt32], + )?, + // same type + case( + vec![DataType::UInt32, DataType::UInt32], + Signature::Uniform(2, vec![DataType::UInt32]), + vec![DataType::UInt32, DataType::UInt32], + )?, + case( + vec![DataType::UInt32], + Signature::Uniform(1, vec![DataType::Float32, DataType::Float64]), + vec![DataType::Float32], + )?, + // u32 -> f32 + case( + vec![DataType::UInt32, DataType::UInt32], + Signature::Variadic(vec![DataType::Float32]), + vec![DataType::Float32, DataType::Float32], + )?, + // u32 -> f32 + case( + vec![DataType::Float32, DataType::UInt32], + Signature::VariadicEqual, + vec![DataType::Float32, DataType::Float32], + )?, + ]; + + for case in cases { + let observed = format!("{:?}", coerce(&case.0, &case.1, &case.2)?); + let expected = format!("{:?}", case.3); + assert_eq!(observed, expected); + } + + // now cases that are expected to fail + let cases = vec![ + // we do not know how to cast bool to UInt16 => fail + case( + vec![DataType::Boolean], + Signature::Uniform(1, vec![DataType::UInt16]), + vec![], + )?, + // u32 and bool are not uniform + case( + vec![DataType::UInt32, DataType::Boolean], + Signature::VariadicEqual, + vec![], + )?, + // bool is not castable to u32 + case( + 
vec![DataType::Boolean, DataType::Boolean], + Signature::Variadic(vec![DataType::UInt32]), + vec![], + )?, + ]; + + for case in cases { + if let Ok(_) = coerce(&case.0, &case.1, &case.2) { + return Err(ExecutionError::General(format!( + "Error was expected in {:?}", + case + ))); + } + } + + Ok(()) + } +} diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs index ca5908748a13d..8da35fe9293b5 100644 --- a/rust/datafusion/src/execution/physical_plan/udf.rs +++ b/rust/datafusion/src/execution/physical_plan/udf.rs @@ -87,7 +87,7 @@ impl ScalarFunction { /// Scalar UDF Physical Expression pub struct ScalarFunctionExpr { - fun: Box, + fun: ScalarUdf, name: String, args: Vec>, return_type: DataType, @@ -108,7 +108,7 @@ impl ScalarFunctionExpr { /// Create a new Scalar function pub fn new( name: &str, - fun: Box, + fun: ScalarUdf, args: Vec>, return_type: &DataType, ) -> Self { diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 1f70353902faa..82b11974bb6c6 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -30,7 +30,7 @@ use crate::datasource::parquet::ParquetTable; use crate::datasource::TableProvider; use crate::error::{ExecutionError, Result}; use crate::{ - execution::physical_plan::expressions::binary_operator_data_type, + execution::physical_plan::{expressions::binary_operator_data_type, functions}, sql::parser::FileType, }; use arrow::record_batch::RecordBatch; @@ -190,7 +190,14 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result { let expr = create_name(expr, input_schema)?; Ok(format!("CAST({} as {:?})", expr, data_type)) } - Expr::ScalarFunction { name, args, .. } => { + Expr::ScalarFunction { fun, args, .. } => { + let mut names = Vec::with_capacity(args.len()); + for e in args { + names.push(create_name(e, input_schema)?); + } + Ok(format!("{}({})", fun, names.join(","))) + } + Expr::ScalarUDF { name, args, .. 
} => { let mut names = Vec::with_capacity(args.len()); for e in args { names.push(create_name(e, input_schema)?); @@ -258,9 +265,16 @@ pub enum Expr { /// Whether to put Nulls before all other data values nulls_first: bool, }, - /// scalar function + /// scalar function. ScalarFunction { - /// Name of the function + /// The function + fun: functions::ScalarFunction, + /// List of expressions to feed to the functions as arguments + args: Vec, + }, + /// scalar udf. + ScalarUDF { + /// The function's name name: String, /// List of expressions to feed to the functions as arguments args: Vec, @@ -286,7 +300,14 @@ impl Expr { Expr::Column(name) => Ok(schema.field_with_name(name)?.data_type().clone()), Expr::Literal(l) => l.get_datatype(), Expr::Cast { data_type, .. } => Ok(data_type.clone()), - Expr::ScalarFunction { return_type, .. } => Ok(return_type.clone()), + Expr::ScalarUDF { return_type, .. } => Ok(return_type.clone()), + Expr::ScalarFunction { fun, args } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + functions::return_type(fun, &data_types) + } Expr::AggregateFunction { name, args, .. } => { match name.to_uppercase().as_str() { "MIN" | "MAX" => args[0].get_type(schema), @@ -360,6 +381,7 @@ impl Expr { }, Expr::Cast { expr, .. } => expr.nullable(input_schema), Expr::ScalarFunction { .. } => Ok(true), + Expr::ScalarUDF { .. } => Ok(true), Expr::AggregateFunction { .. } => Ok(true), Expr::Not(expr) => expr.nullable(input_schema), Expr::IsNull(_) => Ok(false), @@ -608,36 +630,42 @@ pub fn lit(n: T) -> Expr { /// Create an convenience function representing a unary scalar function macro_rules! 
unary_math_expr { - ($NAME:expr, $FUNC:ident) => { + ($ENUM:ident, $FUNC:ident) => { #[allow(missing_docs)] pub fn $FUNC(e: Expr) -> Expr { - scalar_function($NAME, vec![e], DataType::Float64) + Expr::ScalarFunction { + fun: functions::ScalarFunction::$ENUM, + args: vec![e], + } } }; } // generate methods for creating the supported unary math expressions -unary_math_expr!("sqrt", sqrt); -unary_math_expr!("sin", sin); -unary_math_expr!("cos", cos); -unary_math_expr!("tan", tan); -unary_math_expr!("asin", asin); -unary_math_expr!("acos", acos); -unary_math_expr!("atan", atan); -unary_math_expr!("floor", floor); -unary_math_expr!("ceil", ceil); -unary_math_expr!("round", round); -unary_math_expr!("trunc", trunc); -unary_math_expr!("abs", abs); -unary_math_expr!("signum", signum); -unary_math_expr!("exp", exp); -unary_math_expr!("log", ln); -unary_math_expr!("log2", log2); -unary_math_expr!("log10", log10); +unary_math_expr!(Sqrt, sqrt); +unary_math_expr!(Sin, sin); +unary_math_expr!(Cos, cos); +unary_math_expr!(Tan, tan); +unary_math_expr!(Asin, asin); +unary_math_expr!(Acos, acos); +unary_math_expr!(Atan, atan); +unary_math_expr!(Floor, floor); +unary_math_expr!(Ceil, ceil); +unary_math_expr!(Round, round); +unary_math_expr!(Trunc, trunc); +unary_math_expr!(Abs, abs); +unary_math_expr!(Signum, signum); +unary_math_expr!(Exp, exp); +unary_math_expr!(Log, ln); +unary_math_expr!(Log2, log2); +unary_math_expr!(Log10, log10); /// returns the length of a string in bytes pub fn length(e: Expr) -> Expr { - scalar_function("length", vec![e], DataType::UInt32) + Expr::ScalarFunction { + fun: functions::ScalarFunction::Length, + args: vec![e], + } } /// Create an aggregate expression @@ -648,9 +676,9 @@ pub fn aggregate_expr(name: &str, expr: Expr) -> Expr { } } -/// Create an aggregate expression +/// call a scalar UDF pub fn scalar_function(name: &str, expr: Vec, return_type: DataType) -> Expr { - Expr::ScalarFunction { + Expr::ScalarUDF { name: name.to_owned(), args: expr, 
return_type, @@ -688,7 +716,18 @@ impl fmt::Debug for Expr { write!(f, " NULLS LAST") } } - Expr::ScalarFunction { name, ref args, .. } => { + Expr::ScalarFunction { fun, ref args, .. } => { + write!(f, "{}(", fun)?; + for i in 0..args.len() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", args[i])?; + } + + write!(f, ")") + } + Expr::ScalarUDF { name, ref args, .. } => { write!(f, "{}(", name)?; for i in 0..args.len() { if i > 0 { diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index 65940baa34c38..056fc10d91fc1 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -64,7 +64,7 @@ where // modify `expressions` by introducing casts when necessary match expr { - Expr::ScalarFunction { name, .. } => { + Expr::ScalarUDF { name, .. } => { // cast the inputs of scalar functions to the appropriate type where possible match self.scalar_functions.lookup(name) { Some(func_meta) => { diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index b8e037f40ca9a..d2eaaa6cb50ca 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -62,6 +62,7 @@ pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet) -> Result< Expr::Sort { expr, .. } => expr_to_column_names(expr, accum), Expr::AggregateFunction { args, .. } => exprlist_to_column_names(args, accum), Expr::ScalarFunction { args, .. } => exprlist_to_column_names(args, accum), + Expr::ScalarUDF { args, .. } => exprlist_to_column_names(args, accum), Expr::Wildcard => Err(ExecutionError::General( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), @@ -194,6 +195,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::IsNull(e) => Ok(vec![e]), Expr::IsNotNull(e) => Ok(vec![e]), Expr::ScalarFunction { args, .. } => Ok(args.iter().collect()), + Expr::ScalarUDF { args, .. 
} => Ok(args.iter().collect()), Expr::AggregateFunction { args, .. } => Ok(args.iter().collect()), Expr::Cast { expr, .. } => Ok(vec![expr]), Expr::Column(_) => Ok(vec![]), @@ -219,9 +221,13 @@ pub fn rewrite_expression(expr: &Expr, expressions: &Vec) -> Result }), Expr::IsNull(_) => Ok(Expr::IsNull(Box::new(expressions[0].clone()))), Expr::IsNotNull(_) => Ok(Expr::IsNotNull(Box::new(expressions[0].clone()))), - Expr::ScalarFunction { + Expr::ScalarFunction { fun, .. } => Ok(Expr::ScalarFunction { + fun: fun.clone(), + args: expressions.clone(), + }), + Expr::ScalarUDF { name, return_type, .. - } => Ok(Expr::ScalarFunction { + } => Ok(Expr::ScalarUDF { name: name.clone(), return_type: return_type.clone(), args: expressions.clone(), diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index f627d05cf4319..5c8452d86dc65 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -17,6 +17,7 @@ //! SQL Query Planner (produces logical plan from SQL AST) +use std::str::FromStr; use std::sync::Arc; use crate::error::{ExecutionError, Result}; @@ -26,6 +27,7 @@ use crate::logicalplan::{ StringifiedPlan, }; use crate::{ + execution::physical_plan::functions, execution::physical_plan::udf::ScalarFunction, sql::parser::{CreateExternalTable, FileType, Statement as DFStatement}, }; @@ -479,8 +481,20 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { } SQLExpr::Function(function) => { - //TODO: fix this hack let name: String = function.name.to_string(); + + // first, scalar built-in + if let Ok(fun) = functions::ScalarFunction::from_str(&name) { + let args = function + .args + .iter() + .map(|a| self.sql_to_rex(a, schema)) + .collect::>>()?; + + return Ok(Expr::ScalarFunction { fun, args }); + }; + + //TODO: fix this hack match name.to_lowercase().as_ref() { "min" | "max" | "sum" | "avg" => { let rex_args = function @@ -510,6 +524,7 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { args: rex_args, }) } + // finally, 
user-defined scalar functions (UDFs) _ => match self.schema_provider.get_function_meta(&name) { Some(fm) => { let rex_args = function @@ -524,7 +539,7 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { .push(rex_args[i].cast_to(&fm.arg_types[i], schema)?); } - Ok(Expr::ScalarFunction { + Ok(Expr::ScalarUDF { name: name.clone(), args: safe_args, return_type: fm.return_type.clone(), @@ -592,7 +607,7 @@ mod tests { fn select_scalar_func_with_literal_no_relation() { quick_test( "SELECT sqrt(9)", - "Projection: sqrt(CAST(Int64(9) AS Float64))\ + "Projection: sqrt(Int64(9))\ \n EmptyRelation", ); } @@ -730,7 +745,7 @@ mod tests { #[test] fn select_scalar_func() { let sql = "SELECT sqrt(age) FROM person"; - let expected = "Projection: sqrt(CAST(#age AS Float64))\ + let expected = "Projection: sqrt(#age)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -738,7 +753,7 @@ mod tests { #[test] fn select_aliased_scalar_func() { let sql = "SELECT sqrt(age) AS square_people FROM person"; - let expected = "Projection: sqrt(CAST(#age AS Float64)) AS square_people\ + let expected = "Projection: sqrt(#age) AS square_people\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -904,8 +919,8 @@ mod tests { fn get_function_meta(&self, name: &str) -> Option> { match name { - "sqrt" => Some(Arc::new(ScalarFunction::new( - "sqrt", + "my_sqrt" => Some(Arc::new(ScalarFunction::new( + "my_sqrt", vec![DataType::Float64], DataType::Float64, Arc::new(|_| Err(ExecutionError::NotImplemented("".to_string()))), diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 15120d754162e..daf37f409d1b5 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -201,6 +201,30 @@ fn csv_query_avg_sqrt() -> Result<()> { Ok(()) } +#[test] +fn csv_query_sqrt_f32() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // sqrt(f32)'s plan passes + let sql = "SELECT avg(sqrt(c11)) FROM aggregate_test_100"; + let
mut actual = execute(&mut ctx, sql); + actual.sort(); + let expected = "0.6584408483418833".to_string(); + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + +#[test] +fn csv_query_error() -> Result<()> { + // sin(utf8) should error + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + let sql = "SELECT sin(c1) FROM aggregate_test_100"; + let plan = ctx.create_logical_plan(&sql); + assert!(plan.is_err()); + Ok(()) +} + // this query used to deadlock due to the call udf(udf()) #[test] fn csv_query_sqrt_sqrt() -> Result<()> {