diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 31595c980a30d..f1249a2fabdd9 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -56,6 +56,7 @@ prost = { version = "0.11", default-features = false } prost-derive = { version = "0.11", default-features = false } serde = { version = "1.0.136", features = ["derive"] } serde_json = "1.0.82" +tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } tonic = "0.9" url = "2.2" diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index df6ad5a467b60..02dd9c4173251 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -57,6 +57,7 @@ cargo run --example csv_sql - [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) - [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF) +- [`simple_udwf.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) ## Distributed diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs new file mode 100644 index 0000000000000..5af2c6d189533 --- /dev/null +++ b/datafusion-examples/examples/simple_udwf.rs @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::{ + array::{ArrayRef, AsArray, Float64Array}, + datatypes::Float64Type, +}; +use arrow_schema::DataType; +use datafusion::datasource::file_format::options::CsvReadOptions; + +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDF}; + +// create local execution context with `cars.csv` registered as a table named `cars` +async fn create_context() -> Result { + // declare a new context. In spark API, this corresponds to a new spark SQLsession + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). 
+ println!("pwd: {}", std::env::current_dir().unwrap().display()); + let csv_path = format!("datafusion/core/tests/data/cars.csv"); + let read_options = CsvReadOptions::default().has_header(true); + + ctx.register_csv("cars", &csv_path, read_options).await?; + Ok(ctx) +} + +/// In this example we will declare a user defined window function that computes a moving average and then run it using SQL +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context().await?; + + // register the window function with DataFusion so we can call it + ctx.register_udwf(smooth_it()); + + // Use SQL to show the contents of the `cars` table + let df = ctx.sql("SELECT * from cars").await?; + // print the results + df.show().await?; + + // Use SQL to run the new window function: + // + // `PARTITION BY car`: each distinct value of car (red and green) + // should be treated as a separate partition (and will result in + // creating a new `PartitionEvaluator`) + // + // `ORDER BY time`: within each partition ('green' or 'red') the + // rows will be ordered by the value in the `time` column + // + // `evaluate` is invoked with a window defined by the + // SQL. In this case: + // + // The first invocation will be passed row 0, the first row in the + // partition. + // + // The second invocation will be passed rows 0 and 1, the first + // two rows in the partition. + // + // etc. + let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time),\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; + // print the results + df.show().await?; + + // this time, call the new window function with an explicit + // window. The frame is honored because `uses_window_frame` returns true. + // + // `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation + // sees at most 3 rows: the row before, the current row, and the 1 + // row afterward. 
+ let df = ctx.sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + time \ + from cars \ + ORDER BY \ + car", + ).await?; + // print the results + df.show().await?; + + // todo show how to run dataframe API as well + + Ok(()) +} +fn smooth_it() -> WindowUDF { + WindowUDF { + name: String::from("smooth_it"), + // it will take 1 argument -- the column to smooth + signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable), + return_type: Arc::new(return_type), + partition_evaluator_factory: Arc::new(make_partition_evaluator), + } +} + +/// Compute the return type of the smooth_it window function given +/// arguments of `arg_types`. +fn return_type(arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return Err(DataFusionError::Plan(format!( + "my_udwf expects 1 argument, got {}: {:?}", + arg_types.len(), + arg_types + ))); + } + Ok(Arc::new(arg_types[0].clone())) +} + +/// Create a `PartitionEvaluator` to evaluate this function on a new +/// partition. +fn make_partition_evaluator() -> Result> { + Ok(Box::new(MyPartitionEvaluator::new())) +} + +/// This implements the lowest level evaluation for a window function +/// +/// It handles calculating the value of the window function for each +/// distinct value of `PARTITION BY` (each car type in our example) +#[derive(Clone, Debug)] +struct MyPartitionEvaluator {} + +impl MyPartitionEvaluator { + fn new() -> Self { + Self {} + } +} + +/// These different evaluation methods are called depending on the various settings of WindowUDF +impl PartitionEvaluator for MyPartitionEvaluator { + /// Tell DataFusion the window function varies based on the value + /// of the window frame. + fn uses_window_frame(&self) -> bool { + true + } + + /// This function is called once per input row. + /// + /// `range` specifies which indexes of `values` should be + /// considered for the calculation. 
+ /// + /// Note this is the SLOWEST, but simplest, way to evaluate a + /// window function. It is much faster to implement + /// evaluate_all or evaluate_all_with_rank, if possible + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &std::ops::Range, + ) -> Result { + println!("evaluate_inside_range(). range: {range:#?}, values: {values:#?}"); + + // The input argument is an array of floating + // point numbers to calculate a moving average + let arr: &Float64Array = values[0].as_ref().as_primitive::(); + + let range_len = range.end - range.start; + + // our smoothing function will average all the values in the given range + let output = if range_len > 0 { + let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); + Some(sum / range_len as f64) + } else { + None + }; + + Ok(ScalarValue::Float64(output)) + } +} diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 8d9fdc2e537cd..f0b3f2aae7019 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -32,7 +32,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, - DescribeTable, StringifiedPlan, UserDefinedLogicalNode, + DescribeTable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; @@ -786,6 +786,20 @@ impl SessionContext { .insert(f.name.clone(), Arc::new(f)); } + /// Registers a window UDF within this context. + /// + /// Note in SQL queries, window function names are looked up using + /// lowercase unless the query uses quotes. 
For example, + /// + /// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"` + /// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"` + pub fn register_udwf(&self, f: WindowUDF) { + self.state + .write() + .window_functions + .insert(f.name.clone(), Arc::new(f)); + } + /// Creates a [`DataFrame`] for reading a data source. /// /// For more control such as reading multiple files, you can use @@ -1279,6 +1293,10 @@ impl FunctionRegistry for SessionContext { fn udaf(&self, name: &str) -> Result> { self.state.read().udaf(name) } + + fn udwf(&self, name: &str) -> Result> { + self.state.read().udwf(name) + } } /// A planner used to add extensions to DataFusion logical and physical plans. @@ -1329,6 +1347,8 @@ pub struct SessionState { scalar_functions: HashMap>, /// Aggregate functions registered in the context aggregate_functions: HashMap>, + /// Window functions registered in the context + window_functions: HashMap>, /// Deserializer registry for extensions. 
serializer_registry: Arc, /// Session configuration @@ -1423,6 +1443,7 @@ impl SessionState { catalog_list, scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), + window_functions: HashMap::new(), serializer_registry: Arc::new(EmptySerializerRegistry), config, execution_props: ExecutionProps::new(), @@ -1899,6 +1920,11 @@ impl SessionState { &self.aggregate_functions } + /// Return reference to window functions + pub fn window_functions(&self) -> &HashMap> { + &self.window_functions + } + /// Return [SerializerRegistry] for extensions pub fn serializer_registry(&self) -> Arc { self.serializer_registry.clone() @@ -1932,6 +1958,10 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { self.state.aggregate_functions().get(name).cloned() } + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + fn get_variable_type(&self, variable_names: &[String]) -> Option { if variable_names.is_empty() { return None; @@ -1979,6 +2009,16 @@ impl FunctionRegistry for SessionState { )) }) } + + fn udwf(&self, name: &str) -> Result> { + let result = self.window_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Plan(format!( + "There is no UDWF named \"{name}\" in the registry" + )) + }) + } } impl OptimizerConfig for SessionState { @@ -2012,6 +2052,7 @@ impl From<&SessionState> for TaskContext { state.config.clone(), state.scalar_functions.clone(), state.aggregate_functions.clone(), + state.window_functions.clone(), state.runtime_env.clone(), ) } diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 0cd6a746dd353..2d36b31f00e5c 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -26,12 +26,12 @@ use crate::physical_plan::{ udaf, ExecutionPlan, PhysicalExpr, }; use arrow::datatypes::Schema; -use arrow_schema::SchemaRef; +use arrow_schema::{DataType, 
Field, SchemaRef}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ window_function::{BuiltInWindowFunction, WindowFunction}, - WindowFrame, + PartitionEvaluator, WindowFrame, WindowUDF, }; use datafusion_physical_expr::{ window::{BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr}, @@ -92,6 +92,12 @@ pub fn create_window_expr( aggregate, ) } + WindowFunction::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( + create_udwf_window_expr(fun, args, input_schema, name)?, + partition_by, + order_by, + window_frame, + )), }) } @@ -206,6 +212,71 @@ fn create_built_in_window_expr( }) } +/// Creates a `BuiltInWindowFunctionExpr` suitable for a user defined window function +fn create_udwf_window_expr( + fun: &Arc, + args: &[Arc], + input_schema: &Schema, + name: String, +) -> Result> { + // need to get the types into an owned vec for some reason + let input_types: Vec<_> = args + .iter() + .map(|arg| arg.data_type(input_schema).map(|dt| dt.clone())) + .collect::>()?; + + // figure out the output type + let data_type = (fun.return_type)(&input_types)?; + Ok(Arc::new(WindowUDFExpr { + fun: Arc::clone(fun), + args: args.to_vec(), + name, + data_type, + })) +} + +// Implement BuiltInWindowFunctionExpr for WindowUDF +#[derive(Clone, Debug)] +struct WindowUDFExpr { + fun: Arc, + args: Vec>, + /// Display name + name: String, + /// result type + data_type: Arc, +} + +impl BuiltInWindowFunctionExpr for WindowUDFExpr { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn field(&self) -> Result { + let nullable = false; + Ok(Field::new( + &self.name, + self.data_type.as_ref().clone(), + nullable, + )) + } + + fn expressions(&self) -> Vec> { + self.args.clone() + } + + fn create_evaluator(&self) -> Result> { + (self.fun.partition_evaluator_factory)() + } + + fn name(&self) -> &str { + &self.name + } + + fn reverse_expr(&self) -> Option> { + None + } +} + pub(crate) fn calc_requirements< T: Borrow>, S: Borrow, diff 
--git a/datafusion/core/tests/data/cars.csv b/datafusion/core/tests/data/cars.csv new file mode 100644 index 0000000000000..bc40f3b01e7a5 --- /dev/null +++ b/datafusion/core/tests/data/cars.csv @@ -0,0 +1,26 @@ +car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +red,21.4,1996-04-12T12:05:05.000000000 +red,21.5,1996-04-12T12:05:06.000000000 +red,19.0,1996-04-12T12:05:07.000000000 +red,18.0,1996-04-12T12:05:08.000000000 +red,17.0,1996-04-12T12:05:09.000000000 +red,7.0,1996-04-12T12:05:10.000000000 +red,7.1,1996-04-12T12:05:11.000000000 +red,7.2,1996-04-12T12:05:12.000000000 +red,3.0,1996-04-12T12:05:13.000000000 +red,1.0,1996-04-12T12:05:14.000000000 +red,0.0,1996-04-12T12:05:15.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +green,10.4,1996-04-12T12:05:05.000000000 +green,10.5,1996-04-12T12:05:06.000000000 +green,11.0,1996-04-12T12:05:07.000000000 +green,12.0,1996-04-12T12:05:08.000000000 +green,14.0,1996-04-12T12:05:09.000000000 +green,15.0,1996-04-12T12:05:10.000000000 +green,15.1,1996-04-12T12:05:11.000000000 +green,15.2,1996-04-12T12:05:12.000000000 +green,8.0,1996-04-12T12:05:13.000000000 +green,2.0,1996-04-12T12:05:14.000000000 diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs new file mode 100644 index 0000000000000..ab6f51c47ba78 --- /dev/null +++ b/datafusion/core/tests/user_defined/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Tests for User Defined Aggregate Functions +mod user_defined_aggregates; + +/// Tests for User Defined Plans +mod user_defined_plan; + +/// Tests for User Defined Window Functions +mod user_defined_window_functions; diff --git a/datafusion/core/tests/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs similarity index 100% rename from datafusion/core/tests/user_defined_aggregates.rs rename to datafusion/core/tests/user_defined/user_defined_aggregates.rs diff --git a/datafusion/core/tests/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs similarity index 100% rename from datafusion/core/tests/user_defined_plan.rs rename to datafusion/core/tests/user_defined/user_defined_plan.rs diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs new file mode 100644 index 0000000000000..a9cd8a993158f --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -0,0 +1,548 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains end to end tests of creating +//! user defined window functions + +use std::{ + ops::Range, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +use arrow::array::AsArray; +use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; +use arrow_schema::DataType; +use datafusion::{assert_batches_eq, prelude::SessionContext}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ + function::PartitionEvaluatorFactory, window_state::WindowAggState, + PartitionEvaluator, ReturnTypeFunction, Signature, Volatility, WindowUDF, +}; + +/// A query with a window function evaluated over the entire partition +const UNBOUNDED_WINDOW_QUERY: &'static str = "SELECT x, y, val, \ + odd_counter(val) OVER (PARTITION BY x ORDER BY y) \ + from t ORDER BY x, y"; + +/// A query with a window function evaluated over a moving window +const BOUNDED_WINDOW_QUERY: &'static str = + "SELECT x, y, val, \ + odd_counter(val) OVER (PARTITION BY x ORDER BY y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) \ + from t ORDER BY x, y"; + +/// Test to show the contents of the setup +#[tokio::test] +async fn test_setup() { + let test_state = TestState::new(); + let TestContext { ctx, test_state: _ } = TestContext::new(test_state); + + let sql = "SELECT * from t order by x, y"; + let expected = vec![ + "+---+---+-----+", + "| x | y | val |", + "+---+---+-----+", + "| 1 | a | 0 |", + "| 1 | b | 1 |", + "| 1 | c | 2 |", + "| 2 | d | 3 |", + "| 2 | e | 4 |", + "| 2 | f | 5 |", + "| 2 | g | 6 |", + "| 2 | h | 6 |", + "| 2 | i | 6 |", + "| 2 
| j | 6 |", + "+---+---+-----+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); +} + +/// Basic user defined window function +#[tokio::test] +async fn test_udwf() { + let test_state = TestState::new(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+--------------------+", + "| x | y | val | odd_counter(t.val) |", + "+---+---+-----+--------------------+", + "| 1 | a | 0 | 1 |", + "| 1 | b | 1 | 1 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 2 |", + "| 2 | e | 4 | 2 |", + "| 2 | f | 5 | 2 |", + "| 2 | g | 6 | 2 |", + "| 2 | h | 6 | 2 |", + "| 2 | i | 6 | 2 |", + "| 2 | j | 6 | 2 |", + "+---+---+-----+--------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap() + ); + // evaluated on 2 distinct batches (when x=1 and x=2) + assert_eq!(test_state.evaluate_all_called(), 2); +} + +/// Basic user defined window function with bounded window +#[tokio::test] +async fn test_udwf_bounded_window_ignores_frame() { + let test_state = TestState::new(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + // Since the UDWF doesn't say it needs the window frame, the frame is ignored + let expected = vec![ + "+---+---+-----+--------------------+", + "| x | y | val | odd_counter(t.val) |", + "+---+---+-----+--------------------+", + "| 1 | a | 0 | 1 |", + "| 1 | b | 1 | 1 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 2 |", + "| 2 | e | 4 | 2 |", + "| 2 | f | 5 | 2 |", + "| 2 | g | 6 | 2 |", + "| 2 | h | 6 | 2 |", + "| 2 | i | 6 | 2 |", + "| 2 | j | 6 | 2 |", + "+---+---+-----+--------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap() + ); + // evaluated on 2 distinct batches (when x=1 and x=2) + assert_eq!(test_state.evaluate_called(), 0); + assert_eq!(test_state.evaluate_all_called(), 2); +} + +/// Basic user defined window function with bounded window +#[tokio::test] +async fn 
test_udwf_bounded_window() { + let test_state = TestState::new().with_uses_window_frame(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+--------------------+", + "| x | y | val | odd_counter(t.val) |", + "+---+---+-----+--------------------+", + "| 1 | a | 0 | 1 |", + "| 1 | b | 1 | 1 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 1 |", + "| 2 | e | 4 | 2 |", + "| 2 | f | 5 | 1 |", + "| 2 | g | 6 | 1 |", + "| 2 | h | 6 | 0 |", + "| 2 | i | 6 | 0 |", + "| 2 | j | 6 | 0 |", + "+---+---+-----+--------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap() + ); + // Evaluate is called for each input rows + assert_eq!(test_state.evaluate_called(), 10); + assert_eq!(test_state.evaluate_all_called(), 0); +} + +/// Basic stateful user defined window function +#[tokio::test] +async fn test_stateful_udwf() { + let test_state = TestState::new() + .with_supports_bounded_execution() + .with_uses_window_frame(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+--------------------+", + "| x | y | val | odd_counter(t.val) |", + "+---+---+-----+--------------------+", + "| 1 | a | 0 | 0 |", + "| 1 | b | 1 | 1 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 1 |", + "| 2 | e | 4 | 1 |", + "| 2 | f | 5 | 2 |", + "| 2 | g | 6 | 2 |", + "| 2 | h | 6 | 2 |", + "| 2 | i | 6 | 2 |", + "| 2 | j | 6 | 2 |", + "+---+---+-----+--------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap() + ); + assert_eq!(test_state.evaluate_called(), 10); + assert_eq!(test_state.update_state_called(), 10); + assert_eq!(test_state.evaluate_all_called(), 0); +} + +/// Basic stateful user defined window function with bounded window +#[tokio::test] +async fn test_stateful_udwf_bounded_window() { + let test_state = TestState::new() + .with_supports_bounded_execution() + 
.with_uses_window_frame(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+--------------------+", + "| x | y | val | odd_counter(t.val) |", + "+---+---+-----+--------------------+", + "| 1 | a | 0 | 1 |", + "| 1 | b | 1 | 1 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 1 |", + "| 2 | e | 4 | 2 |", + "| 2 | f | 5 | 1 |", + "| 2 | g | 6 | 1 |", + "| 2 | h | 6 | 0 |", + "| 2 | i | 6 | 0 |", + "| 2 | j | 6 | 0 |", + "+---+---+-----+--------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap() + ); + // Evaluate and update_state is called for each input row + assert_eq!(test_state.evaluate_called(), 10); + assert_eq!(test_state.update_state_called(), 10); + assert_eq!(test_state.evaluate_all_called(), 0); +} + +/// user defined window function using rank +#[tokio::test] +async fn test_udwf_query_include_rank() { + let test_state = TestState::new().with_include_rank(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+--------------------+", + "| x | y | val | odd_counter(t.val) |", + "+---+---+-----+--------------------+", + "| 1 | a | 0 | 3 |", + "| 1 | b | 1 | 2 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 7 |", + "| 2 | e | 4 | 6 |", + "| 2 | f | 5 | 5 |", + "| 2 | g | 6 | 4 |", + "| 2 | h | 6 | 3 |", + "| 2 | i | 6 | 2 |", + "| 2 | j | 6 | 1 |", + "+---+---+-----+--------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap() + ); + assert_eq!(test_state.evaluate_called(), 0); + assert_eq!(test_state.evaluate_all_called(), 0); + // evaluated on 2 distinct batches (when x=1 and x=2) + assert_eq!(test_state.evaluate_all_with_rank_called(), 2); +} + +/// user defined window function with bounded window using rank +#[tokio::test] +async fn test_udwf_bounded_query_include_rank() { + let test_state = TestState::new().with_include_rank(); + 
let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+--------------------+", + "| x | y | val | odd_counter(t.val) |", + "+---+---+-----+--------------------+", + "| 1 | a | 0 | 3 |", + "| 1 | b | 1 | 2 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 7 |", + "| 2 | e | 4 | 6 |", + "| 2 | f | 5 | 5 |", + "| 2 | g | 6 | 4 |", + "| 2 | h | 6 | 3 |", + "| 2 | i | 6 | 2 |", + "| 2 | j | 6 | 1 |", + "+---+---+-----+--------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap() + ); + assert_eq!(test_state.evaluate_called(), 0); + assert_eq!(test_state.evaluate_all_called(), 0); + // evaluated on 2 distinct batches (when x=1 and x=2) + assert_eq!(test_state.evaluate_all_with_rank_called(), 2); +} + +async fn execute(ctx: &SessionContext, sql: &str) -> Result> { + ctx.sql(sql).await?.collect().await +} + +/// A context with a table "t" and the "odd_counter" +/// window function registered. 
+/// +/// "t" contains this data: +/// +/// ```text +/// x | y | val +/// 1 | a | 0 +/// 1 | b | 1 +/// 1 | c | 2 +/// 2 | d | 3 +/// 2 | e | 4 +/// 2 | f | 5 +/// 2 | g | 6 +/// 2 | h | 6 +/// 2 | i | 6 +/// 2 | j | 6 +/// ``` +struct TestContext { + ctx: SessionContext, + test_state: Arc, +} + +impl TestContext { + fn new(test_state: TestState) -> Self { + let test_state = Arc::new(test_state); + let x = Int64Array::from(vec![1, 1, 1, 2, 2, 2, 2, 2, 2, 2]); + let y = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]); + let val = Int64Array::from(vec![0, 1, 2, 3, 4, 5, 6, 6, 6, 6]); + + let batch = RecordBatch::try_from_iter(vec![ + ("x", Arc::new(x) as _), + ("y", Arc::new(y) as _), + ("val", Arc::new(val) as _), + ]) + .unwrap(); + + let mut ctx = SessionContext::new(); + + ctx.register_batch("t", batch).unwrap(); + + // Tell DataFusion about the window function + OddCounter::register(&mut ctx, Arc::clone(&test_state)); + + Self { ctx, test_state } + } +} + +#[derive(Debug, Default)] +struct TestState { + /// How many times was `evaluate_all` called? + evaluate_all_called: AtomicUsize, + /// How many times was `evaluate` called? + evaluate_called: AtomicUsize, + /// How many times was `update_state` called? + update_state_called: AtomicUsize, + /// How many times was `evaluate_all_with_rank` called? + evaluate_all_with_rank_called: AtomicUsize, + /// should the functions say they use the window frame? 
+ uses_window_frame: bool, + /// should the functions say they support bounded execution + supports_bounded_execution: bool, + /// should the functions say they need to include rank + include_rank: bool, +} + +impl TestState { + fn new() -> Self { + Default::default() + } + + /// Set that this function should use the window frame + fn with_uses_window_frame(mut self) -> Self { + self.uses_window_frame = true; + self + } + + /// Set that this function should use bounded / stateful execution + fn with_supports_bounded_execution(mut self) -> Self { + self.supports_bounded_execution = true; + self + } + + /// Set that this function should include rank + fn with_include_rank(mut self) -> Self { + self.include_rank = true; + self + } + + /// return the evaluate_all_called counter + fn evaluate_all_called(&self) -> usize { + self.evaluate_all_called.load(Ordering::SeqCst) + } + + /// update the evaluate_all_called counter + fn inc_evaluate_all_called(&self) { + self.evaluate_all_called.fetch_add(1, Ordering::SeqCst); + } + + /// return the evaluate_called counter + fn evaluate_called(&self) -> usize { + self.evaluate_called.load(Ordering::SeqCst) + } + + /// update the evaluate_called counter + fn inc_evaluate_called(&self) { + self.evaluate_called.fetch_add(1, Ordering::SeqCst); + } + + /// return the update_state_called counter + fn update_state_called(&self) -> usize { + self.update_state_called.load(Ordering::SeqCst) + } + + /// update the update_state_called counter + fn inc_update_state_called(&self) { + self.update_state_called.fetch_add(1, Ordering::SeqCst); + } + + /// return the evaluate_all_with_rank_called counter + fn evaluate_all_with_rank_called(&self) -> usize { + self.evaluate_all_with_rank_called.load(Ordering::SeqCst) + } + + /// update the evaluate_all_with_rank_called counter + fn inc_evaluate_all_with_rank_called(&self) { + self.evaluate_all_with_rank_called + .fetch_add(1, Ordering::SeqCst); + } +} + +// Partition Evaluator that counts the number of odd 
numbers in the window frame using evaluate +#[derive(Debug)] +struct OddCounter { + test_state: Arc, +} + +impl OddCounter { + fn new(test_state: Arc) -> Self { + Self { test_state } + } + + fn register(ctx: &mut SessionContext, test_state: Arc) { + let name = "odd_counter"; + let volatility = Volatility::Immutable; + + let signature = Signature::exact(vec![DataType::Int64], volatility); + + let return_type = Arc::new(DataType::Int64); + let return_type: ReturnTypeFunction = + Arc::new(move |_| Ok(Arc::clone(&return_type))); + + let partition_evaluator_factory: PartitionEvaluatorFactory = + Arc::new(move || Ok(Box::new(OddCounter::new(Arc::clone(&test_state))))); + + ctx.register_udwf(WindowUDF::new( + name, + &signature, + &return_type, + &partition_evaluator_factory, + )) + } +} + +impl PartitionEvaluator for OddCounter { + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &Range, + ) -> Result { + println!("evaluate, values: {values:#?}, range: {range:?}"); + + self.test_state.inc_evaluate_called(); + let values: &Int64Array = values.get(0).unwrap().as_primitive(); + let values = values.slice(range.start, range.len()); + let scalar = ScalarValue::Int64(Some(odd_count(&values))); + Ok(scalar) + } + + fn evaluate_all( + &mut self, + values: &[arrow_array::ArrayRef], + num_rows: usize, + ) -> Result { + println!("evaluate_all, values: {values:#?}, num_rows: {num_rows}"); + + self.test_state.inc_evaluate_all_called(); + Ok(odd_count_arr( + values.get(0).unwrap().as_primitive(), + num_rows, + )) + } + + fn evaluate_all_with_rank( + &self, + num_rows: usize, + ranks_in_partition: &[Range], + ) -> Result { + self.test_state.inc_evaluate_all_with_rank_called(); + println!("evaluate_all_with_rank, values: {num_rows:#?}, ranks_in_partitions: {ranks_in_partition:?}"); + // when evaluating with ranks, just return the inverse rank instead + let array: Int64Array = ranks_in_partition + .iter() + // cloned range is an iterator + .cloned() + .flatten() + .map(|v| 
(num_rows - v) as i64) + .collect(); + Ok(Arc::new(array)) + } + + fn update_state( + &mut self, + _state: &WindowAggState, + _idx: usize, + _range_columns: &[ArrayRef], + _sort_partition_points: &[Range], + ) -> Result<()> { + self.test_state.inc_update_state_called(); + Ok(()) + } + + fn supports_bounded_execution(&self) -> bool { + self.test_state.supports_bounded_execution + } + + fn uses_window_frame(&self) -> bool { + self.test_state.uses_window_frame + } + + fn include_rank(&self) -> bool { + self.test_state.include_rank + } +} + +/// returns the number of entries in arr that are odd +fn odd_count(arr: &Int64Array) -> i64 { + arr.iter().filter_map(|x| x.map(|x| x % 2)).sum() +} + +/// returns an array of num_rows that has the number of odd values in `arr` +fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> arrow_array::ArrayRef { + let array: Int64Array = std::iter::repeat(odd_count(arr)).take(num_rows).collect(); + Arc::new(array) +} diff --git a/datafusion/core/tests/user_defined_integration.rs b/datafusion/core/tests/user_defined_integration.rs new file mode 100644 index 0000000000000..4f9cc89529adb --- /dev/null +++ b/datafusion/core/tests/user_defined_integration.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Run all tests that are found in the `user_defined` directory +mod user_defined; + +#[cfg(test)] +#[ctor::ctor] +fn init() { + // Enable RUST_LOG logging configuration for test + let _ = env_logger::try_init(); +} diff --git a/datafusion/execution/src/registry.rs b/datafusion/execution/src/registry.rs index ef06c74cc2923..9ba487e715b3b 100644 --- a/datafusion/execution/src/registry.rs +++ b/datafusion/execution/src/registry.rs @@ -18,7 +18,7 @@ //! FunctionRegistry trait use datafusion_common::Result; -use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode}; +use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use std::{collections::HashSet, sync::Arc}; /// A registry knows how to build logical expressions out of user-defined function' names @@ -31,6 +31,9 @@ pub trait FunctionRegistry { /// Returns a reference to the udaf named `name`. fn udaf(&self, name: &str) -> Result>; + + /// Returns a reference to the udwf named `name`. + fn udwf(&self, name: &str) -> Result>; } /// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. 
diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index ca1bc9369e351..6aacf2de5f84d 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -24,7 +24,7 @@ use datafusion_common::{ config::{ConfigOptions, Extensions}, DataFusionError, Result, }; -use datafusion_expr::{AggregateUDF, ScalarUDF}; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use crate::{ config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry, @@ -48,18 +48,23 @@ pub struct TaskContext { scalar_functions: HashMap>, /// Aggregate functions associated with this task context aggregate_functions: HashMap>, + /// Window functions associated with this task context + window_functions: HashMap>, /// Runtime environment associated with this task context runtime: Arc, } impl TaskContext { - /// Create a new task context instance + /// Create a new [`TaskContext`] instance. + /// + /// Most users will use [`SessionContext::task_ctx()`] to create [`TaskContext`]s pub fn new( task_id: Option, session_id: String, session_config: SessionConfig, scalar_functions: HashMap>, aggregate_functions: HashMap>, + window_functions: HashMap>, runtime: Arc, ) -> Self { Self { @@ -68,6 +73,7 @@ impl TaskContext { session_config, scalar_functions, aggregate_functions, + window_functions, runtime, } } @@ -92,6 +98,7 @@ impl TaskContext { config.set(&k, &v)?; } let session_config = SessionConfig::from(config); + let window_functions = HashMap::new(); Ok(Self::new( Some(task_id), @@ -99,6 +106,7 @@ impl TaskContext { session_config, scalar_functions, aggregate_functions, + window_functions, runtime, )) } @@ -153,6 +161,16 @@ impl FunctionRegistry for TaskContext { )) }) } + + fn udwf(&self, name: &str) -> Result> { + let result = self.window_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Internal(format!( + "There is no UDWF named \"{name}\" in the TaskContext" + )) + }) + } } #[cfg(test)] diff --git 
a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index bd242c493e432..795eb6b50e5d5 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,7 +17,7 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::{Accumulator, BuiltinScalarFunction, Signature}; +use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature}; use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue}; use arrow::datatypes::DataType; use datafusion_common::utils::datafusion_strsim; @@ -45,6 +45,11 @@ pub type ReturnTypeFunction = pub type AccumulatorFunctionImplementation = Arc Result> + Send + Sync>; +/// Factory that creates a PartitionEvaluator for the given window +/// function +pub type PartitionEvaluatorFactory = + Arc Result> + Send + Sync>; + /// Factory that returns the types used by an aggregator to serialize /// its state, given its return datatype. pub type StateTypeFunction = diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index ccb972887778a..b80d0eea87452 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -49,6 +49,7 @@ pub mod tree_node; pub mod type_coercion; mod udaf; mod udf; +mod udwf; pub mod utils; pub mod window_frame; pub mod window_function; @@ -76,6 +77,7 @@ pub use signature::{Signature, TypeSignature, Volatility}; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; pub use udf::ScalarUDF; +pub use udwf::WindowUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion/expr/src/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs index 6b159d71059e7..169044c803f4c 100644 --- a/datafusion/expr/src/partition_evaluator.rs +++ b/datafusion/expr/src/partition_evaluator.rs @@ -160,10 +160,11 @@ pub trait PartitionEvaluator: 
Debug + Send { /// `num_rows` is requied to correctly compute the output in case /// `values.len() == 0` /// - /// Using this function is an optimization: certain window - /// functions are not affected by the window frame definition, and - /// thus using `evaluate`, DataFusion can skip the (costly) window - /// frame boundary calculation. + /// Implementing this function is an optimization: certain window + /// functions are not affected by the window frame definition or + /// the query doesn't have a frame, and `evaluate_all` skips the + /// (costly) window frame boundary calculation and the overhead of + /// calling `evaluate` for each output row. /// /// For example, the `LAG` built in window function does not use /// the values of its window frame (it can be computed in one shot @@ -200,8 +201,6 @@ pub trait PartitionEvaluator: Debug + Send { /// Evaluate window function on a range of rows in an input /// partition.x /// - /// Only used for stateful evaluation. - /// /// This is the simplest and most general function to implement /// but also the least performant as it creates output one row at /// a time. It is typically much faster to implement stateful @@ -209,7 +208,7 @@ pub trait PartitionEvaluator: Debug + Send { /// trait. /// /// Returns a [`ScalarValue`] that is the value of the window - /// function within the rangefor the entire partition + /// function within `range` for the entire partition fn evaluate( &mut self, _values: &[ArrayRef], @@ -260,7 +259,8 @@ pub trait PartitionEvaluator: Debug + Send { /// Can the window function be incrementally computed using /// bounded memory? /// - /// If this function returns true, implement [`PartitionEvaluator::evaluate`] + /// If this function returns true, implement [`PartitionEvaluator::evaluate`] and + /// [`PartitionEvaluator::update_state`]. 
fn supports_bounded_execution(&self) -> bool { false } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs new file mode 100644 index 0000000000000..c6fe510a8fe8a --- /dev/null +++ b/datafusion/expr/src/udwf.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Support for user-defined window functions (UDWFs) + +use std::fmt::{self, Debug, Display, Formatter}; + +use crate::{function::PartitionEvaluatorFactory, ReturnTypeFunction, Signature}; + +/// Logical representation of a user-defined window function (UDWF) +/// A UDWF is different from a UDF in that it is stateful across batches. 
+/// +/// Window Frames: +/// +/// TODO add a diagram here showing the input and the output w/ frames +/// (or document that elsewhere and link here) +#[derive(Clone)] +pub struct WindowUDF { + /// name + pub name: String, + /// signature + pub signature: Signature, + /// Return type + pub return_type: ReturnTypeFunction, + /// Return the partition evaluator + pub partition_evaluator_factory: PartitionEvaluatorFactory, +} + +impl Debug for WindowUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("WindowUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("return_type", &"") + .field("partition_evaluator_factory", &"") + .finish_non_exhaustive() + } +} + +/// Defines how the WindowUDF is shown to users +impl Display for WindowUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}", self.name) + } +} + +impl PartialEq for WindowUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.signature == other.signature + } +} + +impl Eq for WindowUDF {} + +impl std::hash::Hash for WindowUDF { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.signature.hash(state); + } +} + +impl WindowUDF { + /// Create a new WindowUDF + pub fn new( + name: &str, + signature: &Signature, + return_type: &ReturnTypeFunction, + partition_evaluator_factory: &PartitionEvaluatorFactory, + ) -> Self { + Self { + name: name.to_owned(), + signature: signature.clone(), + return_type: return_type.clone(), + partition_evaluator_factory: partition_evaluator_factory.clone(), + } + } +} diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index a5b58d173c1ad..6f30bff69b6ae 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -19,11 +19,12 @@ //! sets of rows that are related to the current query row. //! //! see also -//! 
use crate::aggregate_function::AggregateFunction; use crate::type_coercion::functions::data_types; -use crate::{aggregate_function, AggregateUDF, Signature, TypeSignature, Volatility}; +use crate::{ + aggregate_function, AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF, +}; use arrow::datatypes::DataType; use datafusion_common::{DataFusionError, Result}; use std::sync::Arc; @@ -33,11 +34,14 @@ use strum_macros::EnumIter; /// WindowFunction #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum WindowFunction { - /// window function that leverages an aggregate function + /// A window function that leverages a built-in aggregate function AggregateFunction(AggregateFunction), - /// window function that leverages a built-in window function + /// A built-in window function BuiltInWindowFunction(BuiltInWindowFunction), + /// A user defined aggregate function AggregateUDF(Arc), + /// A user defined window function + WindowUDF(Arc), } /// Find DataFusion's built-in window function by name. 
@@ -69,6 +73,7 @@ impl fmt::Display for WindowFunction { WindowFunction::AggregateFunction(fun) => fun.fmt(f), WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), + WindowFunction::WindowUDF(fun) => fun.fmt(f), } } } @@ -166,6 +171,9 @@ pub fn return_type( WindowFunction::AggregateUDF(fun) => { Ok((*(fun.return_type)(input_expr_types)?).clone()) } + WindowFunction::WindowUDF(fun) => { + Ok((*(fun.return_type)(input_expr_types)?).clone()) + } } } @@ -202,6 +210,7 @@ pub fn signature(fun: &WindowFunction) -> Signature { WindowFunction::AggregateFunction(fun) => aggregate_function::signature(fun), WindowFunction::BuiltInWindowFunction(fun) => signature_for_built_in(fun), WindowFunction::AggregateUDF(fun) => fun.signature.clone(), + WindowFunction::WindowUDF(fun) => fun.signature.clone(), } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index afd0f61292014..f08f357ec42c4 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -178,11 +178,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn find_window_func(&self, name: &str) -> Result { window_function::find_df_window_func(name) + // next check user defined aggregates .or_else(|| { self.schema_provider .get_aggregate_meta(name) .map(WindowFunction::AggregateUDF) }) + // next check user defined window functions + .or_else(|| { + self.schema_provider + .get_window_meta(name) + .map(WindowFunction::WindowUDF) + }) .ok_or_else(|| { DataFusionError::Plan(format!("There is no window function named {name}")) }) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index ceec01037425f..26ff5466f408b 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -22,6 +22,7 @@ use std::vec; use arrow_schema::*; use datafusion_common::field_not_found; +use datafusion_expr::WindowUDF; use sqlparser::ast::ExactNumberInfo; use 
sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; @@ -46,6 +47,8 @@ pub trait ContextProvider { fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description fn get_aggregate_meta(&self, name: &str) -> Option>; + /// Getter for a UDWF + fn get_window_meta(&self, name: &str) -> Option>; /// Getter for system/user-defined variable type fn get_variable_type(&self, variable_names: &[String]) -> Option;