forked from apache/arrow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ARROW-9751: [Rust] [DataFusion] Allow UDFs to accept multiple data ty…
…pes per argument This PR aligns UDFs registration and declaration to be consistent with our built-in functions, so that we can leverage coercion rules on their arguments. For ease of use, this PR introduces a function `create_udf` that simplifies the creation of UDFs with a fixed signature and fixed return type, so that users have a simple interface to declare them. However, underneath, the UDFs have the same capabilities as built-in functions, in that they can be as generic as built-in functions (arbitrary types, etc.). Specific achievements of this PR: * Added example (120 LOC) of how to declare and register a UDF * Deprecated the type coercer optimizer, since it was causing logical schemas to become misaligned and cause our end-to-end tests to fail when implicit casting was required, and replaced it by what we already do for built-ins * Made UDFs use the same interfaces as built-in functions Note that this PR is built on top of apache#8032. Closes apache#7967 from jorgecarleitao/clean Authored-by: Jorge C. Leitao <[email protected]> Signed-off-by: Andy Grove <[email protected]>
- Loading branch information
1 parent
359da8e
commit 02f2e8a
Showing
13 changed files
with
427 additions
and
378 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
// Licensed to the Apache Software Foundation (ASF) under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, | ||
// software distributed under the License is distributed on an | ||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
// KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
use arrow::{ | ||
array::{Array, ArrayRef, Float32Array, Float64Array, Float64Builder}, | ||
datatypes::DataType, | ||
record_batch::RecordBatch, | ||
util::pretty, | ||
}; | ||
|
||
use datafusion::error::Result; | ||
use datafusion::{physical_plan::functions::ScalarFunctionImplementation, prelude::*}; | ||
use std::sync::Arc; | ||
|
||
// create local execution context with an in-memory table | ||
fn create_context() -> Result<ExecutionContext> { | ||
use arrow::datatypes::{Field, Schema}; | ||
use datafusion::datasource::MemTable; | ||
// define a schema. | ||
let schema = Arc::new(Schema::new(vec![ | ||
Field::new("a", DataType::Float32, false), | ||
Field::new("b", DataType::Float64, false), | ||
])); | ||
|
||
// define data. | ||
let batch = RecordBatch::try_new( | ||
schema.clone(), | ||
vec![ | ||
Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])), | ||
Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), | ||
], | ||
)?; | ||
|
||
// declare a new context. In spark API, this corresponds to a new spark SQLsession | ||
let mut ctx = ExecutionContext::new(); | ||
|
||
// declare a table in memory. In spark API, this corresponds to createDataFrame(...). | ||
let provider = MemTable::new(schema, vec![vec![batch]])?; | ||
ctx.register_table("t", Box::new(provider)); | ||
Ok(ctx) | ||
} | ||
|
||
/// In this example we will declare a single-type, single return type UDF that exponentiates f64, a^b | ||
fn main() -> Result<()> { | ||
let mut ctx = create_context()?; | ||
|
||
// First, declare the actual implementation of the calculation | ||
let pow: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| { | ||
// in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to: | ||
// 1. cast the values to the type we want | ||
// 2. perform the computation for every element in the array (using a loop or SIMD) | ||
// 3. construct the resulting array | ||
|
||
// this is guaranteed by DataFusion based on the function's signature. | ||
assert_eq!(args.len(), 2); | ||
|
||
// 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics! | ||
let base = &args[0] | ||
.as_any() | ||
.downcast_ref::<Float64Array>() | ||
.expect("cast failed"); | ||
let exponent = &args[1] | ||
.as_any() | ||
.downcast_ref::<Float64Array>() | ||
.expect("cast failed"); | ||
|
||
// this is guaranteed by DataFusion. We place it just to make it obvious. | ||
assert_eq!(exponent.len(), base.len()); | ||
|
||
// 2. Arrow's builder is used to construct an Arrow array. | ||
let mut builder = Float64Builder::new(base.len()); | ||
for index in 0..base.len() { | ||
// in arrow, any value can be null. | ||
// Here we decide to make our UDF to return null when either base or exponent is null. | ||
if base.is_null(index) || exponent.is_null(index) { | ||
builder.append_null()?; | ||
} else { | ||
// 3. computation. Since we do not have any SIMD `pow` operation at our hands, | ||
// we loop over each entry. Array's values are obtained via `.value(index)`. | ||
let value = base.value(index).powf(exponent.value(index)); | ||
builder.append_value(value)?; | ||
} | ||
} | ||
Ok(Arc::new(builder.finish())) | ||
}); | ||
|
||
// Next: | ||
// * give it a name so that it shows nicely when the plan is printed | ||
// * declare what input it expects | ||
// * declare its return type | ||
let pow = create_udf( | ||
"pow", | ||
// expects two f64 | ||
vec![DataType::Float64, DataType::Float64], | ||
// returns f64 | ||
Arc::new(DataType::Float64), | ||
pow, | ||
); | ||
|
||
// finally, register the UDF | ||
ctx.register_udf(pow); | ||
|
||
// at this point, we can use it. Note that the code below can be in a | ||
// scope on which we do not have access to `pow`. | ||
|
||
// get a DataFrame from the context | ||
let df = ctx.table("t")?; | ||
|
||
// get the udf registry. | ||
let f = df.registry(); | ||
|
||
// equivalent to `'SELECT pow(a, b) FROM t'` | ||
let df = df.select(vec![f.udf("pow", vec![col("a"), col("b")])?])?; | ||
|
||
// note that "b" is f32, not f64. DataFusion coerces the types to match the UDF's signature. | ||
|
||
// execute the query | ||
let results = df.collect()?; | ||
|
||
// print the results | ||
pretty::print_batches(&results)?; | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.