Skip to content

Commit

Permalink
ARROW-9836: [Rust][DataFusion] Improve API for usage of UDFs
Browse files Browse the repository at this point in the history
See associated issue and document for details.

The gist is that currently, users call UDFs through

```
df.select(scalar_function("my_sqrt", vec![col("a")], DataType::Float64))
```

and this PR proposes a change to

```
let functions = df.registry();

df.select(functions.udf("my_sqrt", vec![col("a")])?)
```

so that they do not have to remember the UDF's return type when using it (and it also simplifies a whole lot of other things for us internally).

Closes #8032 from jorgecarleitao/registry

Authored-by: Jorge C. Leitao <[email protected]>
Signed-off-by: Andy Grove <[email protected]>
  • Loading branch information
jorgecarleitao authored and andygrove committed Sep 7, 2020
1 parent 5d66bc5 commit 4186a66
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 122 deletions.
17 changes: 16 additions & 1 deletion rust/datafusion/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use crate::arrow::record_batch::RecordBatch;
use crate::error::Result;
use crate::logical_plan::{Expr, LogicalPlan};
use crate::logical_plan::{Expr, FunctionRegistry, LogicalPlan};
use arrow::datatypes::Schema;
use std::sync::Arc;

Expand Down Expand Up @@ -188,4 +188,19 @@ pub trait DataFrame {
/// # }
/// ```
fn explain(&self, verbose: bool) -> Result<Arc<dyn DataFrame>>;

/// Returns the `FunctionRegistry` used to plan calls to user-defined functions (UDFs).
///
/// The registry builds the logical expression for a UDF call from the UDF's name and
/// its argument expressions, as sketched in the example below.
///
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # fn main() -> Result<()> {
/// let mut ctx = ExecutionContext::new();
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?;
/// let f = df.registry();
/// // use f.udf("name", vec![...]) to use the udf
/// # Ok(())
/// # }
/// ```
fn registry(&self) -> &dyn FunctionRegistry;
}
43 changes: 35 additions & 8 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,11 @@ use crate::datasource::parquet::ParquetTable;
use crate::datasource::TableProvider;
use crate::error::{ExecutionError, Result};
use crate::execution::dataframe_impl::DataFrameImpl;
use crate::logical_plan::{LogicalPlan, LogicalPlanBuilder};
use crate::logical_plan::{Expr, FunctionRegistry, LogicalPlan, LogicalPlanBuilder};
use crate::optimizer::filter_push_down::FilterPushDown;
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::projection_push_down::ProjectionPushDown;
use crate::optimizer::{
filter_push_down::FilterPushDown, type_coercion::TypeCoercionRule,
};
use crate::optimizer::type_coercion::TypeCoercionRule;
use crate::physical_plan::common;
use crate::physical_plan::csv::CsvReadOptions;
use crate::physical_plan::merge::MergeExec;
Expand Down Expand Up @@ -294,7 +293,7 @@ impl ExecutionContext {
pub fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> {
let plan = ProjectionPushDown::new().optimize(&plan)?;
let plan = FilterPushDown::new().optimize(&plan)?;
let plan = TypeCoercionRule::new(self).optimize(&plan)?;
let plan = TypeCoercionRule::new().optimize(&plan)?;
Ok(plan)
}

Expand Down Expand Up @@ -371,6 +370,11 @@ impl ExecutionContext {

Ok(())
}

/// Returns the registry that allows constructing logical expressions that call UDFs.
///
/// The returned reference is this context's `state`, exposed as a `FunctionRegistry`.
pub fn registry(&self) -> &dyn FunctionRegistry {
&self.state
}
}

impl ScalarFunctionRegistry for ExecutionContext {
Expand Down Expand Up @@ -445,7 +449,27 @@ impl SchemaProvider for ExecutionContextState {
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarFunction>> {
self.scalar_functions
.get(name)
.and_then(|func| Some(Arc::new(func.as_ref().clone())))
.and_then(|func| Some(func.clone()))
}
}

impl FunctionRegistry for ExecutionContextState {
    /// Returns the keys of `scalar_functions`, i.e. the names under which UDFs can be
    /// looked up in this state.
    fn udfs(&self) -> HashSet<String> {
        self.scalar_functions.keys().cloned().collect()
    }

    /// Builds an `Expr::ScalarUDF` that calls the UDF registered as `name` with `args`.
    ///
    /// # Errors
    ///
    /// Returns `ExecutionError::General` when `scalar_functions` has no entry for `name`.
    fn udf(&self, name: &str, args: Vec<Expr>) -> Result<Expr> {
        // Match on the lookup directly instead of `is_none()` followed by `unwrap()`.
        match self.scalar_functions.get(name) {
            Some(fun) => Ok(Expr::ScalarUDF {
                fun: fun.clone(),
                args,
            }),
            // `format!` already yields a `String`; no extra `.to_string()` is needed.
            None => Err(ExecutionError::General(format!(
                "There is no UDF named \"{}\" in the registry",
                name
            ))),
        }
    }
}

Expand All @@ -454,7 +478,7 @@ mod tests {

use super::*;
use crate::datasource::MemTable;
use crate::logical_plan::{aggregate_expr, col, scalar_function};
use crate::logical_plan::{aggregate_expr, col};
use crate::physical_plan::udf::ScalarUdf;
use crate::test;
use arrow::array::{ArrayRef, Int32Array};
Expand Down Expand Up @@ -997,13 +1021,16 @@ mod tests {

ctx.register_udf(my_add);

// from here on, we may be in a different scope. We would still like to be able
// to call UDFs.

let t = ctx.table("t")?;

let plan = LogicalPlanBuilder::from(&t.to_logical_plan())
.project(vec![
col("a"),
col("b"),
scalar_function("my_add", vec![col("a"), col("b")], DataType::Int32),
ctx.registry().udf("my_add", vec![col("a"), col("b")])?,
])?
.build()?;

Expand Down
62 changes: 60 additions & 2 deletions rust/datafusion/src/execution/dataframe_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::arrow::record_batch::RecordBatch;
use crate::dataframe::*;
use crate::error::Result;
use crate::execution::context::{ExecutionContext, ExecutionContextState};
use crate::logical_plan::{col, Expr, LogicalPlan, LogicalPlanBuilder};
use crate::logical_plan::{col, Expr, FunctionRegistry, LogicalPlan, LogicalPlanBuilder};
use arrow::datatypes::Schema;

/// Implementation of DataFrame API
Expand Down Expand Up @@ -124,6 +124,10 @@ impl DataFrame for DataFrameImpl {
.build()?;
Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
}

/// Returns the context state held by this data frame (`ctx_state`) as a
/// `FunctionRegistry`.
fn registry(&self) -> &dyn FunctionRegistry {
&self.ctx_state
}
}

#[cfg(test)]
Expand All @@ -132,7 +136,15 @@ mod tests {
use crate::datasource::csv::CsvReadOptions;
use crate::execution::context::ExecutionContext;
use crate::logical_plan::*;
use crate::test;
use crate::{
physical_plan::udf::{ScalarFunction, ScalarUdf},
test,
};
use arrow::{
array::{ArrayRef, Float64Array},
compute::add,
datatypes::DataType,
};

#[test]
fn select_columns() -> Result<()> {
Expand Down Expand Up @@ -232,6 +244,52 @@ mod tests {
Ok(())
}

#[test]
fn registry() -> Result<()> {
    // View a dynamically typed array as a `Float64Array`, panicking on a type mismatch.
    fn as_f64(array: &ArrayRef) -> &Float64Array {
        array
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("cast failed")
    }

    let mut ctx = ExecutionContext::new();
    register_aggregate_csv(&mut ctx)?;

    // Kernel of the UDF: element-wise sum of its first two arguments.
    // NOTE(review): the kernel reads `args[1]`, yet only one `Float64` argument type is
    // declared below and both queries pass a single column. This test only compares
    // logical plans and does not appear to invoke the kernel — confirm the intended arity.
    let kernel: ScalarUdf = Arc::new(|args: &[ArrayRef]| {
        let lhs = as_f64(&args[0]);
        let rhs = as_f64(&args[1]);
        Ok(Arc::new(add(lhs, rhs)?))
    });

    // Declare the UDF and register it with the context in one step.
    ctx.register_udf(ScalarFunction::new(
        "my_add",
        vec![DataType::Float64],
        DataType::Float64,
        kernel,
    ));

    // Plan the query through the DataFrame API, resolving the UDF via the registry.
    let table = ctx.table("aggregate_test_100")?;
    let udf_call = table.registry().udf("my_add", vec![col("c12")])?;
    let projected = table.select(vec![udf_call])?;
    let plan = projected.to_logical_plan();

    // Plan the equivalent query through SQL.
    let sql_plan =
        ctx.create_logical_plan("SELECT my_add(c12) FROM aggregate_test_100")?;

    // Both routes must produce the same logical plan.
    assert_same_plan(&plan, &sql_plan);

    Ok(())
}

/// Compare the formatted string representation of two plans for equality
fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) {
assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2));
Expand Down
60 changes: 27 additions & 33 deletions rust/datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
//! Logical query plans can then be optimized and executed directly, or translated into
//! physical query plans and executed.
use std::{fmt, sync::Arc};
use std::{collections::HashSet, fmt, sync::Arc};

use arrow::datatypes::{DataType, Field, Schema, SchemaRef};

use crate::datasource::csv::{CsvFile, CsvReadOptions};
use crate::datasource::parquet::ParquetTable;
use crate::datasource::TableProvider;
use crate::error::{ExecutionError, Result};
use crate::physical_plan::udf;
use crate::{
physical_plan::{
expressions::binary_operator_data_type, functions, type_coercion::can_coerce_from,
Expand Down Expand Up @@ -199,12 +200,12 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result<String> {
}
Ok(format!("{}({})", fun, names.join(",")))
}
Expr::ScalarUDF { name, args, .. } => {
Expr::ScalarUDF { fun, args, .. } => {
let mut names = Vec::with_capacity(args.len());
for e in args {
names.push(create_name(e, input_schema)?);
}
Ok(format!("{}({})", name, names.join(",")))
Ok(format!("{}({})", fun.name, names.join(",")))
}
Expr::AggregateFunction { name, args, .. } => {
let mut names = Vec::with_capacity(args.len());
Expand All @@ -226,7 +227,7 @@ pub fn exprlist_to_fields(expr: &[Expr], input_schema: &Schema) -> Result<Vec<Fi
}

/// Relation expression
#[derive(Clone, PartialEq)]
#[derive(Clone)]
pub enum Expr {
/// An aliased expression
Alias(Box<Expr>, String),
Expand Down Expand Up @@ -276,12 +277,10 @@ pub enum Expr {
},
/// scalar udf.
ScalarUDF {
/// The function's name
name: String,
/// The function
fun: Arc<udf::ScalarFunction>,
/// List of expressions to feed to the functions as arguments
args: Vec<Expr>,
/// The `DataType` the expression will yield
return_type: DataType,
},
/// aggregate function
AggregateFunction {
Expand All @@ -302,7 +301,7 @@ impl Expr {
Expr::Column(name) => Ok(schema.field_with_name(name)?.data_type().clone()),
Expr::Literal(l) => l.get_datatype(),
Expr::Cast { data_type, .. } => Ok(data_type.clone()),
Expr::ScalarUDF { return_type, .. } => Ok(return_type.clone()),
Expr::ScalarUDF { fun, .. } => Ok(fun.return_type.clone()),
Expr::ScalarFunction { fun, args } => {
let data_types = args
.iter()
Expand Down Expand Up @@ -686,15 +685,6 @@ pub fn aggregate_expr(name: &str, expr: Expr) -> Expr {
}
}

/// Creates an `Expr::ScalarUDF` that calls the scalar UDF named `name`.
///
/// * `name` - name of the UDF; it is copied into the expression.
/// * `expr` - argument expressions to feed to the UDF.
/// * `return_type` - the `DataType` recorded as the expression's result type; it is
///   supplied by the caller and stored here without any validation.
pub fn scalar_function(name: &str, expr: Vec<Expr>, return_type: DataType) -> Expr {
Expr::ScalarUDF {
name: name.to_owned(),
args: expr,
return_type,
}
}

impl fmt::Debug for Expr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Expand Down Expand Up @@ -737,8 +727,8 @@ impl fmt::Debug for Expr {

write!(f, ")")
}
Expr::ScalarUDF { name, ref args, .. } => {
write!(f, "{}(", name)?;
Expr::ScalarUDF { fun, ref args, .. } => {
write!(f, "{}(", fun.name)?;
for i in 0..args.len() {
if i > 0 {
write!(f, ", ")?;
Expand Down Expand Up @@ -1038,6 +1028,15 @@ impl fmt::Debug for LogicalPlan {
}
}

/// A registry knows how to build logical expressions out of user-defined functions' names.
pub trait FunctionRegistry {
/// Returns the set of names of all available UDFs.
fn udfs(&self) -> HashSet<String>;

/// Constructs a logical expression that calls the UDF named `name` with arguments `args`.
///
/// Returns a `Result` so that implementations can report failures, for example a
/// `name` that is not present in the registry.
fn udf(&self, name: &str, args: Vec<Expr>) -> Result<Expr>;
}

/// Builder for logical plans
pub struct LogicalPlanBuilder {
plan: LogicalPlan,
Expand Down Expand Up @@ -1137,19 +1136,14 @@ impl LogicalPlanBuilder {
/// Apply a projection
pub fn project(&self, expr: Vec<Expr>) -> Result<Self> {
let input_schema = self.plan.schema();
let projected_expr = if expr.contains(&Expr::Wildcard) {
let mut expr_vec = vec![];
(0..expr.len()).for_each(|i| match &expr[i] {
Expr::Wildcard => {
(0..input_schema.fields().len())
.for_each(|i| expr_vec.push(col(input_schema.field(i).name())));
}
_ => expr_vec.push(expr[i].clone()),
});
expr_vec
} else {
expr.clone()
};
let mut projected_expr = vec![];
(0..expr.len()).for_each(|i| match &expr[i] {
Expr::Wildcard => {
(0..input_schema.fields().len())
.for_each(|i| projected_expr.push(col(input_schema.field(i).name())));
}
_ => projected_expr.push(expr[i].clone()),
});

let schema =
Schema::new(exprlist_to_fields(&projected_expr, input_schema.as_ref())?);
Expand Down
Loading

0 comments on commit 4186a66

Please sign in to comment.