diff --git a/rust/datafusion/src/dataframe.rs b/rust/datafusion/src/dataframe.rs index 978c50db67eec..b839c48f0bfca 100644 --- a/rust/datafusion/src/dataframe.rs +++ b/rust/datafusion/src/dataframe.rs @@ -19,7 +19,7 @@ use crate::arrow::record_batch::RecordBatch; use crate::error::Result; -use crate::logical_plan::{Expr, LogicalPlan}; +use crate::logical_plan::{Expr, FunctionRegistry, LogicalPlan}; use arrow::datatypes::Schema; use std::sync::Arc; @@ -188,4 +188,19 @@ pub trait DataFrame { /// # } /// ``` fn explain(&self, verbose: bool) -> Result>; + + /// Return a `FunctionRegistry` used to plan udf's calls + /// + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # fn main() -> Result<()> { + /// let mut ctx = ExecutionContext::new(); + /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?; + /// let f = df.registry(); + /// // use f.udf("name", vec![...]) to use the udf + /// # Ok(()) + /// # } + /// ``` + fn registry(&self) -> &dyn FunctionRegistry; } diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 2933d1a4865aa..85d664c018d69 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -36,12 +36,11 @@ use crate::datasource::parquet::ParquetTable; use crate::datasource::TableProvider; use crate::error::{ExecutionError, Result}; use crate::execution::dataframe_impl::DataFrameImpl; -use crate::logical_plan::{LogicalPlan, LogicalPlanBuilder}; +use crate::logical_plan::{Expr, FunctionRegistry, LogicalPlan, LogicalPlanBuilder}; +use crate::optimizer::filter_push_down::FilterPushDown; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::projection_push_down::ProjectionPushDown; -use crate::optimizer::{ - filter_push_down::FilterPushDown, type_coercion::TypeCoercionRule, -}; +use crate::optimizer::type_coercion::TypeCoercionRule; use crate::physical_plan::common; use crate::physical_plan::csv::CsvReadOptions; use crate::physical_plan::merge::MergeExec; @@ -294,7 +293,7 @@ impl ExecutionContext { pub fn optimize(&self, plan: &LogicalPlan) -> Result { let plan = ProjectionPushDown::new().optimize(&plan)?; let plan = FilterPushDown::new().optimize(&plan)?; - let plan = TypeCoercionRule::new(self).optimize(&plan)?; + let plan = TypeCoercionRule::new().optimize(&plan)?; Ok(plan) } @@ -371,6 +370,11 @@ impl ExecutionContext { Ok(()) } + + /// get the registry, that allows to construct logical expressions of UDFs + pub fn registry(&self) -> &dyn FunctionRegistry { + &self.state + } } impl ScalarFunctionRegistry for ExecutionContext { @@ -445,7 +449,27 @@ impl SchemaProvider for ExecutionContextState { fn get_function_meta(&self, name: &str) -> Option> { self.scalar_functions .get(name) - .and_then(|func| Some(Arc::new(func.as_ref().clone()))) + .and_then(|func| Some(func.clone())) + } +} + +impl FunctionRegistry for ExecutionContextState { + fn udfs(&self) -> HashSet { + self.scalar_functions.keys().cloned().collect() + } + + fn udf(&self, name: &str, args: Vec) -> Result { + let result = self.scalar_functions.get(name); + if result.is_none() { + Err(ExecutionError::General( + format!("There is no UDF named \"{}\" in the registry", name).to_string(), + )) + } else { + Ok(Expr::ScalarUDF { + fun: result.unwrap().clone(), + args, + }) + } } } @@ -454,7 +478,7 @@ mod tests { use super::*; use crate::datasource::MemTable; - use crate::logical_plan::{aggregate_expr, col, scalar_function}; + use crate::logical_plan::{aggregate_expr, col}; use crate::physical_plan::udf::ScalarUdf; use crate::test; use arrow::array::{ArrayRef, Int32Array}; @@ -997,13 +1021,16 @@ mod tests { ctx.register_udf(my_add); + // from here on, we may be in a different scope. We would still like to be able + // to call UDFs. + let t = ctx.table("t")?; let plan = LogicalPlanBuilder::from(&t.to_logical_plan()) .project(vec![ col("a"), col("b"), - scalar_function("my_add", vec![col("a"), col("b")], DataType::Int32), + ctx.registry().udf("my_add", vec![col("a"), col("b")])?, ])? .build()?; diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index 302843a674546..a343d2917c1e1 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -23,7 +23,7 @@ use crate::arrow::record_batch::RecordBatch; use crate::dataframe::*; use crate::error::Result; use crate::execution::context::{ExecutionContext, ExecutionContextState}; -use crate::logical_plan::{col, Expr, LogicalPlan, LogicalPlanBuilder}; +use crate::logical_plan::{col, Expr, FunctionRegistry, LogicalPlan, LogicalPlanBuilder}; use arrow::datatypes::Schema; /// Implementation of DataFrame API @@ -124,6 +124,10 @@ impl DataFrame for DataFrameImpl { .build()?; Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) } + + fn registry(&self) -> &dyn FunctionRegistry { + &self.ctx_state + } } #[cfg(test)] @@ -132,7 +136,15 @@ mod tests { use crate::datasource::csv::CsvReadOptions; use crate::execution::context::ExecutionContext; use crate::logical_plan::*; - use crate::test; + use crate::{ + physical_plan::udf::{ScalarFunction, ScalarUdf}, + test, + }; + use arrow::{ + array::{ArrayRef, Float64Array}, + compute::add, + datatypes::DataType, + }; #[test] fn select_columns() -> Result<()> { @@ -232,6 +244,52 @@ mod tests { Ok(()) } + #[test] + fn registry() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx)?; + + // declare the udf + let my_add: ScalarUdf = Arc::new(|args: &[ArrayRef]| { + let l = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + let r = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); + Ok(Arc::new(add(l, r)?)) + }); + + let my_add = ScalarFunction::new( + "my_add", + vec![DataType::Float64], + DataType::Float64, + my_add, + ); + + // register the udf + ctx.register_udf(my_add); + + // build query with a UDF using DataFrame API + let df = ctx.table("aggregate_test_100")?; + + let f = df.registry(); + + let df = df.select(vec![f.udf("my_add", vec![col("c12")])?])?; + let plan = df.to_logical_plan(); + + // build query using SQL + let sql_plan = + ctx.create_logical_plan("SELECT my_add(c12) FROM aggregate_test_100")?; + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) + } + /// Compare the formatted string representation of two plans for equality fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) { assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2)); diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 5c57087f49a86..b2736257322fb 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -21,7 +21,7 @@ //! Logical query plans can then be optimized and executed directly, or translated into //! physical query plans and executed. -use std::{fmt, sync::Arc}; +use std::{collections::HashSet, fmt, sync::Arc}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -29,6 +29,7 @@ use crate::datasource::csv::{CsvFile, CsvReadOptions}; use crate::datasource::parquet::ParquetTable; use crate::datasource::TableProvider; use crate::error::{ExecutionError, Result}; +use crate::physical_plan::udf; use crate::{ physical_plan::{ expressions::binary_operator_data_type, functions, type_coercion::can_coerce_from, @@ -199,12 +200,12 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result { } Ok(format!("{}({})", fun, names.join(","))) } - Expr::ScalarUDF { name, args, .. } => { + Expr::ScalarUDF { fun, args, .. } => { let mut names = Vec::with_capacity(args.len()); for e in args { names.push(create_name(e, input_schema)?); } - Ok(format!("{}({})", name, names.join(","))) + Ok(format!("{}({})", fun.name, names.join(","))) } Expr::AggregateFunction { name, args, .. } => { let mut names = Vec::with_capacity(args.len()); @@ -226,7 +227,7 @@ pub fn exprlist_to_fields(expr: &[Expr], input_schema: &Schema) -> Result, String), @@ -276,12 +277,10 @@ pub enum Expr { }, /// scalar udf. ScalarUDF { - /// The function's name - name: String, + /// The function + fun: Arc, /// List of expressions to feed to the functions as arguments args: Vec, - /// The `DataType` the expression will yield - return_type: DataType, }, /// aggregate function AggregateFunction { @@ -302,7 +301,7 @@ impl Expr { Expr::Column(name) => Ok(schema.field_with_name(name)?.data_type().clone()), Expr::Literal(l) => l.get_datatype(), Expr::Cast { data_type, .. } => Ok(data_type.clone()), - Expr::ScalarUDF { return_type, .. } => Ok(return_type.clone()), + Expr::ScalarUDF { fun, .. } => Ok(fun.return_type.clone()), Expr::ScalarFunction { fun, args } => { let data_types = args .iter() @@ -686,15 +685,6 @@ pub fn aggregate_expr(name: &str, expr: Expr) -> Expr { } } -/// call a scalar UDF -pub fn scalar_function(name: &str, expr: Vec, return_type: DataType) -> Expr { - Expr::ScalarUDF { - name: name.to_owned(), - args: expr, - return_type, - } -} - impl fmt::Debug for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -737,8 +727,8 @@ impl fmt::Debug for Expr { write!(f, ")") } - Expr::ScalarUDF { name, ref args, .. } => { - write!(f, "{}(", name)?; + Expr::ScalarUDF { fun, ref args, .. } => { + write!(f, "{}(", fun.name)?; for i in 0..args.len() { if i > 0 { write!(f, ", ")?; @@ -1038,6 +1028,15 @@ impl fmt::Debug for LogicalPlan { } } +/// A registry knows how to build logical expressions out of user-defined function' names +pub trait FunctionRegistry { + /// Set of all available udfs. + fn udfs(&self) -> HashSet; + + /// Constructs a logical expression with a call to the udf. + fn udf(&self, name: &str, args: Vec) -> Result; +} + /// Builder for logical plans pub struct LogicalPlanBuilder { plan: LogicalPlan, @@ -1137,19 +1136,14 @@ impl LogicalPlanBuilder { /// Apply a projection pub fn project(&self, expr: Vec) -> Result { let input_schema = self.plan.schema(); - let projected_expr = if expr.contains(&Expr::Wildcard) { - let mut expr_vec = vec![]; - (0..expr.len()).for_each(|i| match &expr[i] { - Expr::Wildcard => { - (0..input_schema.fields().len()) - .for_each(|i| expr_vec.push(col(input_schema.field(i).name()))); - } - _ => expr_vec.push(expr[i].clone()), - }); - expr_vec - } else { - expr.clone() - }; + let mut projected_expr = vec![]; + (0..expr.len()).for_each(|i| match &expr[i] { + Expr::Wildcard => { + (0..input_schema.fields().len()) + .for_each(|i| projected_expr.push(col(input_schema.field(i).name()))); + } + _ => projected_expr.push(expr[i].clone()), + }); let schema = Schema::new(exprlist_to_fields(&projected_expr, input_schema.as_ref())?); diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index 76ea71107b651..9e74b6934282e 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -22,34 +22,24 @@ use arrow::datatypes::Schema; -use crate::error::{ExecutionError, Result}; +use crate::error::Result; use crate::logical_plan::Expr; use crate::logical_plan::LogicalPlan; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; -use crate::physical_plan::{ - expressions::numerical_coercion, udf::ScalarFunctionRegistry, -}; +use crate::physical_plan::expressions::numerical_coercion; use utils::optimize_explain; /// Optimizer that applies coercion rules to expressions in the logical plan. /// /// This optimizer does not alter the structure of the plan, it only changes expressions on it. -pub struct TypeCoercionRule<'a, P> -where - P: ScalarFunctionRegistry, -{ - scalar_functions: &'a P, -} +pub struct TypeCoercionRule {} -impl<'a, P> TypeCoercionRule<'a, P> -where - P: ScalarFunctionRegistry, -{ +impl TypeCoercionRule { /// Create a new type coercion optimizer rule using meta-data about registered /// scalar functions - pub fn new(scalar_functions: &'a P) -> Self { - Self { scalar_functions } + pub fn new() -> Self { + Self {} } /// Rewrite an expression to include explicit CAST operations when required @@ -64,32 +54,22 @@ where // modify `expressions` by introducing casts when necessary match expr { - Expr::ScalarUDF { name, .. } => { + Expr::ScalarUDF { fun, .. } => { // cast the inputs of scalar functions to the appropriate type where possible - match self.scalar_functions.lookup(name) { - Some(func_meta) => { - for i in 0..expressions.len() { - let actual_type = expressions[i].get_type(schema)?; - let required_type = &func_meta.arg_types[i]; - if &actual_type != required_type { - // attempt to coerce using numerical coercion - // todo: also try string coercion. - if let Some(cast_to_type) = - numerical_coercion(&actual_type, required_type) - { - expressions[i] = - expressions[i].cast_to(&cast_to_type, schema)? - }; - // not possible: do nothing and let the plan fail with a clear error message - }; - } - } - _ => { - return Err(ExecutionError::General(format!( - "Invalid scalar function {}", - name - ))) - } + for i in 0..expressions.len() { + let actual_type = expressions[i].get_type(schema)?; + let required_type = &fun.arg_types[i]; + if &actual_type != required_type { + // attempt to coerce using numerical coercion + // todo: also try string coercion. + if let Some(cast_to_type) = + numerical_coercion(&actual_type, required_type) + { + expressions[i] = + expressions[i].cast_to(&cast_to_type, schema)? + }; + // not possible: do nothing and let the plan fail with a clear error message + }; } } _ => {} @@ -98,10 +78,7 @@ where } } -impl<'a, P> OptimizerRule for TypeCoercionRule<'a, P> -where - P: ScalarFunctionRegistry, -{ +impl OptimizerRule for TypeCoercionRule { fn optimize(&mut self, plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Explain { diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index 1c5f5104e3b41..0e12f00c30e23 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -225,11 +225,8 @@ pub fn rewrite_expression(expr: &Expr, expressions: &Vec) -> Result fun: fun.clone(), args: expressions.clone(), }), - Expr::ScalarUDF { - name, return_type, .. - } => Ok(Expr::ScalarUDF { - name: name.clone(), - return_type: return_type.clone(), + Expr::ScalarUDF { fun, .. } => Ok(Expr::ScalarUDF { + fun: fun.clone(), args: expressions.clone(), }), Expr::AggregateFunction { name, .. } => Ok(Expr::AggregateFunction { diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index a66eb67c2c2f9..d5fd705c6f300 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -368,32 +368,22 @@ impl DefaultPhysicalPlanner { .collect::>>()?; functions::create_physical_expr(fun, &physical_args, input_schema) } - Expr::ScalarUDF { - name, - args, - return_type, - } => match ctx_state.scalar_functions.get(name) { - Some(f) => { - let mut physical_args = vec![]; - for e in args { - physical_args.push(self.create_physical_expr( - e, - input_schema, - ctx_state, - )?); - } - Ok(Arc::new(ScalarFunctionExpr::new( - name, - f.fun.clone(), - physical_args, - return_type, - ))) + Expr::ScalarUDF { fun, args } => { + let mut physical_args = vec![]; + for e in args { + physical_args.push(self.create_physical_expr( + e, + input_schema, + ctx_state, + )?); } - _ => Err(ExecutionError::General(format!( - "Invalid scalar function '{:?}'", - name - ))), - }, + Ok(Arc::new(ScalarFunctionExpr::new( + &fun.name, + fun.fun.clone(), + physical_args, + &fun.return_type, + ))) + } other => Err(ExecutionError::NotImplemented(format!( "Physical plan does not support logical expression {:?}", other diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index ed00f3e19f29a..e1a36f40d4d52 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -525,7 +525,7 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { args: rex_args, }) } - // finally, built-in scalar functions + // finally, user-defined functions _ => match self.schema_provider.get_function_meta(&name) { Some(fm) => { let rex_args = function @@ -541,9 +541,8 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { } Ok(Expr::ScalarUDF { - name: name.clone(), + fun: fm.clone(), args: safe_args, - return_type: fm.return_type.clone(), }) } _ => Err(ExecutionError::General(format!(