From 53064e13a0df5b4d9ebf9a6009c3d7996d39cb36 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 7 Jun 2023 13:27:01 -0400 Subject: [PATCH] RFC: User Defined Window Function sketch --- datafusion-examples/Cargo.toml | 1 + datafusion-examples/README.md | 1 + datafusion-examples/examples/simple_udwf.rs | 74 ++++++++++++++++++ .../core/src/physical_plan/windows/mod.rs | 61 ++++++++++++++- datafusion/core/tests/data/cars.csv | 27 +++++++ datafusion/expr/src/lib.rs | 2 + datafusion/expr/src/udwf.rs | 77 +++++++++++++++++++ datafusion/expr/src/window_function.rs | 16 +++- datafusion/physical-expr/src/window/mod.rs | 1 + 9 files changed, 254 insertions(+), 6 deletions(-) create mode 100644 datafusion-examples/examples/simple_udwf.rs create mode 100644 datafusion/core/tests/data/cars.csv create mode 100644 datafusion/expr/src/udwf.rs diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 31595c980a30..f1249a2fabdd 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -56,6 +56,7 @@ prost = { version = "0.11", default-features = false } prost-derive = { version = "0.11", default-features = false } serde = { version = "1.0.136", features = ["derive"] } serde_json = "1.0.82" +tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } tonic = "0.9" url = "2.2" diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index df6ad5a467b6..02dd9c417325 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -57,6 +57,7 @@ cargo run --example csv_sql - [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) - [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF) +- [`simple_udwf.rs`](examples/simple_udwf.rs): Define and invoke a User 
Defined Window Function (UDWF) ## Distributed diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs new file mode 100644 index 000000000000..f84794e274f7 --- /dev/null +++ b/datafusion-examples/examples/simple_udwf.rs @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::datasource::file_format::options::CsvReadOptions; + +use datafusion::error::Result; +use datafusion::prelude::*; + +// create local execution context with `cars.csv` registered as a table named `cars` +async fn create_context() -> Result { + // declare a new context. In spark API, this corresponds to a new spark SQLsession + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). 
+ println!("pwd: {}", std::env::current_dir().unwrap().display()); + let csv_path = format!("datafusion/core/tests/data/cars.csv"); + let read_options = CsvReadOptions::default().has_header(true); + + ctx.register_csv("cars", &csv_path, read_options).await?; + Ok(ctx) +} + +/// In this example we will declare a user defined window function that computes a moving average and then run it using SQL +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context().await?; + + // Use SQL to show the raw contents of the `cars` table + let df = ctx.sql("SELECT * from cars").await?; + // print the results + df.show().await?; + + // Use SQL to run the new window function + // `PARTITION BY car`: each distinct value of car (red and green) should be treated separately + // `ORDER BY time`: within each group (red or green) the values will be ordered by time + let df = ctx + .sql( + "SELECT car, \ + speed, \ + lag(speed, 1) OVER (PARTITION BY car ORDER BY time),\ + time \ + from cars", + ) + .await?; + // print the results + df.show().await?; + + // ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING: run the window function so that each invocation sees at most 5 rows (the current row, the 2 before and the 2 after) + let df = ctx.sql("SELECT car, \ + speed, \ + lag(speed, 1) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING),\ + time \ + from cars").await?; + // print the results + df.show().await?; + + // todo show how to run dataframe API as well + + Ok(()) +} diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 73a3eb10c28f..8fafbef2e55c 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -26,15 +26,15 @@ use crate::physical_plan::{ udaf, ExecutionPlan, PhysicalExpr, }; use arrow::datatypes::Schema; -use arrow_schema::{SchemaRef, SortOptions}; +use arrow_schema::{DataType, Field, SchemaRef, SortOptions}; use datafusion_common::ScalarValue; use 
datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ window_function::{BuiltInWindowFunction, WindowFunction}, - WindowFrame, + WindowFrame, WindowUDF, }; use datafusion_physical_expr::window::{ - BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr, + BuiltInWindowFunctionExpr, PartitionEvaluator, SlidingAggregateWindowExpr, }; use std::borrow::Borrow; use std::convert::TryInto; @@ -97,6 +97,12 @@ pub fn create_window_expr( order_by, window_frame, )), + WindowFunction::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( + create_udwf_window_expr(fun, args, input_schema, name)?, + partition_by, + order_by, + window_frame, + )), }) } @@ -184,6 +190,55 @@ fn create_built_in_window_expr( }) } +/// Creates a `BuiltInWindowFunctionExpr` suitable for a user defined window function +fn create_udwf_window_expr( + fun: &Arc, + args: &[Arc], + input_schema: &Schema, + name: String, +) -> Result> { + // need to get the types into an owned vec for some reason + let input_types: Vec<_> = input_schema.fields().iter().map(|f| f.data_type().clone()).collect(); + // figure out the output type + let data_type = (fun.return_type)(&input_types)?; + Ok(Arc::new(WindowUDFExpr { + fun: Arc::clone(fun), + args: args.to_vec(), + name, + data_type, + })) +} + +// Implement BuiltInWindowFunctionExpr for WindowUDF +#[derive(Clone, Debug)] +struct WindowUDFExpr { + fun: Arc, + args: Vec>, + /// Display name + name: String, + /// result type + data_type: Arc, +} + +impl BuiltInWindowFunctionExpr for WindowUDFExpr { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn field(&self) -> Result { + let nullable = false; + Ok(Field::new(&self.name, self.data_type.as_ref().clone(), nullable)) + } + + fn expressions(&self) -> Vec> { + self.args.clone() + } + + fn create_evaluator(&self) -> Result> { + todo!() + } +} + pub(crate) fn calc_requirements< T: Borrow>, S: Borrow, diff --git a/datafusion/core/tests/data/cars.csv b/datafusion/core/tests/data/cars.csv new file mode 100644 
index 000000000000..24f363ccf432 --- /dev/null +++ b/datafusion/core/tests/data/cars.csv @@ -0,0 +1,27 @@ +car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +red,21.4,1996-04-12T12:05:05.000000000 +red,21.5,1996-04-12T12:05:06.000000000 +red,19.0,1996-04-12T12:05:07.000000000 +red,18.0,1996-04-12T12:05:08.000000000 +red,17.0,1996-04-12T12:05:09.000000000 +red,7.0,1996-04-12T12:05:10.000000000 +red,7.1,1996-04-12T12:05:11.000000000 +red,7.2,1996-04-12T12:05:12.000000000 +red,3.0,1996-04-12T12:05:13.000000000 +red,1.0,1996-04-12T12:05:14.000000000 +red,0.0,1996-04-12T12:05:15.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +green,10.4,1996-04-12T12:05:05.000000000 +green,10.5,1996-04-12T12:05:06.000000000 +green,11.0,1996-04-12T12:05:07.000000000 +green,12.0,1996-04-12T12:05:08.000000000 +green,14.0,1996-04-12T12:05:09.000000000 +green,15.0,1996-04-12T12:05:10.000000000 +green,15.1,1996-04-12T12:05:11.000000000 +green,15.2,1996-04-12T12:05:12.000000000 +green,8.0,1996-04-12T12:05:13.000000000 +green,2.0,1996-04-12T12:05:14.000000000 +green,0.0,1996-04-12T12:05:15.000000000 diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 5945480aba1d..c5c364986606 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -49,6 +49,7 @@ pub mod tree_node; pub mod type_coercion; mod udaf; mod udf; +mod udwf; pub mod utils; pub mod window_frame; pub mod window_function; @@ -74,6 +75,7 @@ pub use signature::{Signature, TypeSignature, Volatility}; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; pub use udf::ScalarUDF; +pub use udwf::WindowUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs new file mode 100644 index 
000000000000..a1b767fa2804 --- /dev/null +++ b/datafusion/expr/src/udwf.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Support for user-defined window functions (UDWF) + +use std::fmt::{self, Debug, Formatter}; + +use crate::{ReturnTypeFunction, Signature}; + +/// Logical representation of a user-defined window function (UDWF) +/// A UDWF is different from a UDF in that it is stateful across batches. 
+#[derive(Clone)] +pub struct WindowUDF { + /// name + pub name: String, + /// signature + pub signature: Signature, + /// Return type + pub return_type: ReturnTypeFunction, + // /// actual implementation + // pub accumulator: AccumulatorFunctionImplementation, +} + +impl Debug for WindowUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("WindowUDF").finish_non_exhaustive() + } +} + +impl PartialEq for WindowUDF { + fn eq(&self, other: &Self) -> bool { + todo!(); + //self.name == other.name && self.signature == other.signature + } +} + +impl Eq for WindowUDF {} + +impl std::hash::Hash for WindowUDF { + fn hash(&self, state: &mut H) { + // self.name.hash(state); + // self.signature.hash(state); + } +} + +impl WindowUDF { + // /// Create a new WindowUDF + // pub fn new( + // name: &str, + // signature: &Signature, + // return_type: &ReturnTypeFunction, + // accumulator: &AccumulatorFunctionImplementation, + // state_type: &StateTypeFunction, + // ) -> Self { + // Self { + // name: name.to_owned(), + // signature: signature.clone(), + // return_type: return_type.clone(), + // accumulator: accumulator.clone(), + // state_type: state_type.clone(), + // } + // } +} diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index a5b58d173c1a..7b5a9e512aa0 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -23,7 +23,9 @@ use crate::aggregate_function::AggregateFunction; use crate::type_coercion::functions::data_types; -use crate::{aggregate_function, AggregateUDF, Signature, TypeSignature, Volatility}; +use crate::{ + aggregate_function, AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF, +}; use arrow::datatypes::DataType; use datafusion_common::{DataFusionError, Result}; use std::sync::Arc; @@ -33,11 +35,14 @@ use strum_macros::EnumIter; /// WindowFunction #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum WindowFunction { - /// window function that 
leverages an aggregate function + /// A window function that leverages a built-in aggregate function AggregateFunction(AggregateFunction), - /// window function that leverages a built-in window function + /// A built-in window function BuiltInWindowFunction(BuiltInWindowFunction), + /// A user defined aggregate function AggregateUDF(Arc), + /// A user defined window function + WindowUDF(Arc), } /// Find DataFusion's built-in window function by name. @@ -69,6 +74,7 @@ impl fmt::Display for WindowFunction { WindowFunction::AggregateFunction(fun) => fun.fmt(f), WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), + WindowFunction::WindowUDF(fun) => std::fmt::Debug::fmt(fun, f), } } } @@ -166,6 +172,9 @@ pub fn return_type( WindowFunction::AggregateUDF(fun) => { Ok((*(fun.return_type)(input_expr_types)?).clone()) } + WindowFunction::WindowUDF(fun) => { + Ok((*(fun.return_type)(input_expr_types)?).clone()) + } } } @@ -202,6 +211,7 @@ pub fn signature(fun: &WindowFunction) -> Signature { WindowFunction::AggregateFunction(fun) => aggregate_function::signature(fun), WindowFunction::BuiltInWindowFunction(fun) => signature_for_built_in(fun), WindowFunction::AggregateUDF(fun) => fun.signature.clone(), + WindowFunction::WindowUDF(fun) => fun.signature.clone(), } } diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 4c8b8b5a4e4b..e91fd943117b 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -32,6 +32,7 @@ mod window_frame_state; pub use aggregate::PlainAggregateWindowExpr; pub use built_in::BuiltInWindowExpr; pub use built_in_window_function_expr::BuiltInWindowFunctionExpr; +pub use partition_evaluator::PartitionEvaluator; pub use sliding_aggregate::SlidingAggregateWindowExpr; pub use window_expr::PartitionBatchState; pub use window_expr::PartitionBatches;