Skip to content

Commit

Permalink
RFC: User Defined Window Function sketch
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jun 9, 2023
1 parent 1af846b commit 53064e1
Show file tree
Hide file tree
Showing 9 changed files with 254 additions and 6 deletions.
1 change: 1 addition & 0 deletions datafusion-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ prost = { version = "0.11", default-features = false }
prost-derive = { version = "0.11", default-features = false }
serde = { version = "1.0.136", features = ["derive"] }
serde_json = "1.0.82"
tempfile = "3"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] }
tonic = "0.9"
url = "2.2"
Expand Down
1 change: 1 addition & 0 deletions datafusion-examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ cargo run --example csv_sql
- [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass
- [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF)
- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF)
- [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF)

## Distributed

Expand Down
74 changes: 74 additions & 0 deletions datafusion-examples/examples/simple_udwf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion::datasource::file_format::options::CsvReadOptions;

use datafusion::error::Result;
use datafusion::prelude::*;

// create local execution context with `cars.csv` registered as a table named `cars`
async fn create_context() -> Result<SessionContext> {
// declare a new context. In spark API, this corresponds to a new spark SQLsession
let ctx = SessionContext::new();

// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
println!("pwd: {}", std::env::current_dir().unwrap().display());
let csv_path = format!("datafusion/core/tests/data/cars.csv");
let read_options = CsvReadOptions::default().has_header(true);

ctx.register_csv("cars", &csv_path, read_options).await?;
Ok(ctx)
}

/// In this example we will declare a user defined window function that computes a moving average and then run it using SQL
#[tokio::main]
async fn main() -> Result<()> {
let ctx = create_context().await?;

// Use SQL to run the new window function
let df = ctx.sql("SELECT * from cars").await?;
// print the results
df.show().await?;

// Use SQL to run the new window function
// `PARTITION BY car`:each distinct value of car (red, and green) should be treated separately
// `ORDER BY time`: within each group (greed or green) the values will be orderd by time
let df = ctx
.sql(
"SELECT car, \
speed, \
lag(speed, 1) OVER (PARTITION BY car ORDER BY time),\
time \
from cars",
)
.await?;
// print the results
df.show().await?;

// ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING: Run the window functon so that each invocation only sees 5 rows: the 2 before and 2 after) using
let df = ctx.sql("SELECT car, \
speed, \
lag(speed, 1) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING),\
time \
from cars").await?;
// print the results
df.show().await?;

// todo show how to run dataframe API as well

Ok(())
}
61 changes: 58 additions & 3 deletions datafusion/core/src/physical_plan/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ use crate::physical_plan::{
udaf, ExecutionPlan, PhysicalExpr,
};
use arrow::datatypes::Schema;
use arrow_schema::{SchemaRef, SortOptions};
use arrow_schema::{DataType, Field, SchemaRef, SortOptions};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{
window_function::{BuiltInWindowFunction, WindowFunction},
WindowFrame,
WindowFrame, WindowUDF,
};
use datafusion_physical_expr::window::{
BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr,
BuiltInWindowFunctionExpr, PartitionEvaluator, SlidingAggregateWindowExpr,
};
use std::borrow::Borrow;
use std::convert::TryInto;
Expand Down Expand Up @@ -97,6 +97,12 @@ pub fn create_window_expr(
order_by,
window_frame,
)),
WindowFunction::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new(
create_udwf_window_expr(fun, args, input_schema, name)?,
partition_by,
order_by,
window_frame,
)),
})
}

Expand Down Expand Up @@ -184,6 +190,55 @@ fn create_built_in_window_expr(
})
}

/// Creates a `BuiltInWindowFunctionExpr` suitable for a user defined window function
fn create_udwf_window_expr(
fun: &Arc<WindowUDF>,
args: &[Arc<dyn PhysicalExpr>],
input_schema: &Schema,
name: String,
) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
// need to get the types into an owned vec for some reason
let input_types: Vec<_> = input_schema.fields().iter().map(|f| f.data_type().clone()).collect();
// figure out the output type
let data_type = (fun.return_type)(&input_types)?;
Ok(Arc::new(WindowUDFExpr {
fun: Arc::clone(fun),
args: args.to_vec(),
name,
data_type,
}))
}

// Implement BuiltInWindowFunctionExpr for WindowUDF
#[derive(Clone, Debug)]
struct WindowUDFExpr {
fun: Arc<WindowUDF>,
args: Vec<Arc<dyn PhysicalExpr>>,
/// Display name
name: String,
/// result type
data_type: Arc<DataType>,
}

impl BuiltInWindowFunctionExpr for WindowUDFExpr {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn field(&self) -> Result<Field> {
let nullable = false;
Ok(Field::new(&self.name, self.data_type.as_ref().clone(), nullable))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
self.args.clone()
}

fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
todo!()
}
}

pub(crate) fn calc_requirements<
T: Borrow<Arc<dyn PhysicalExpr>>,
S: Borrow<PhysicalSortExpr>,
Expand Down
27 changes: 27 additions & 0 deletions datafusion/core/tests/data/cars.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
car,speed,time
red,20.0,1996-04-12T12:05:03.000000000
red,20.3,1996-04-12T12:05:04.000000000
red,21.4,1996-04-12T12:05:05.000000000
red,21.5,1996-04-12T12:05:06.000000000
red,19.0,1996-04-12T12:05:07.000000000
red,18.0,1996-04-12T12:05:08.000000000
red,17.0,1996-04-12T12:05:09.000000000
red,7.0,1996-04-12T12:05:10.000000000
red,7.1,1996-04-12T12:05:11.000000000
red,7.2,1996-04-12T12:05:12.000000000
red,3.0,1996-04-12T12:05:13.000000000
red,1.0,1996-04-12T12:05:14.000000000
red,0.0,1996-04-12T12:05:15.000000000
green,10.0,1996-04-12T12:05:03.000000000
green,10.3,1996-04-12T12:05:04.000000000
green,10.4,1996-04-12T12:05:05.000000000
green,10.5,1996-04-12T12:05:06.000000000
green,11.0,1996-04-12T12:05:07.000000000
green,12.0,1996-04-12T12:05:08.000000000
green,14.0,1996-04-12T12:05:09.000000000
green,15.0,1996-04-12T12:05:10.000000000
green,15.1,1996-04-12T12:05:11.000000000
green,15.2,1996-04-12T12:05:12.000000000
green,8.0,1996-04-12T12:05:13.000000000
green,2.0,1996-04-12T12:05:14.000000000
green,0.0,1996-04-12T12:05:15.000000000
2 changes: 2 additions & 0 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub mod tree_node;
pub mod type_coercion;
mod udaf;
mod udf;
mod udwf;
pub mod utils;
pub mod window_frame;
pub mod window_function;
Expand All @@ -74,6 +75,7 @@ pub use signature::{Signature, TypeSignature, Volatility};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::AggregateUDF;
pub use udf::ScalarUDF;
pub use udwf::WindowUDF;
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
pub use window_function::{BuiltInWindowFunction, WindowFunction};

Expand Down
77 changes: 77 additions & 0 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Support for user-defined window (UDWF) window functions
use std::fmt::{self, Debug, Formatter};

use crate::{ReturnTypeFunction, Signature};

/// Logical representation of a user-defined window function (UDWF)
/// A UDAF is different from a UDF in that it is stateful across batches.
#[derive(Clone)]
pub struct WindowUDF {
/// name
pub name: String,
/// signature
pub signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,
// /// actual implementation
// pub accumulator: AccumulatorFunctionImplementation,
}

impl Debug for WindowUDF {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("WindowUDF").finish_non_exhaustive()
}
}

impl PartialEq for WindowUDF {
fn eq(&self, other: &Self) -> bool {
todo!();
//self.name == other.name && self.signature == other.signature
}
}

impl Eq for WindowUDF {}

impl std::hash::Hash for WindowUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
// self.name.hash(state);
// self.signature.hash(state);
}
}

impl WindowUDF {
// /// Create a new WindowUDF
// pub fn new(
// name: &str,
// signature: &Signature,
// return_type: &ReturnTypeFunction,
// accumulator: &AccumulatorFunctionImplementation,
// state_type: &StateTypeFunction,
// ) -> Self {
// Self {
// name: name.to_owned(),
// signature: signature.clone(),
// return_type: return_type.clone(),
// accumulator: accumulator.clone(),
// state_type: state_type.clone(),
// }
// }
}
16 changes: 13 additions & 3 deletions datafusion/expr/src/window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
use crate::aggregate_function::AggregateFunction;
use crate::type_coercion::functions::data_types;
use crate::{aggregate_function, AggregateUDF, Signature, TypeSignature, Volatility};
use crate::{
aggregate_function, AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF,
};
use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result};
use std::sync::Arc;
Expand All @@ -33,11 +35,14 @@ use strum_macros::EnumIter;
/// WindowFunction
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WindowFunction {
/// window function that leverages an aggregate function
/// A built in aggregate function that leverages an aggregate function
AggregateFunction(AggregateFunction),
/// window function that leverages a built-in window function
/// A a built-in window function
BuiltInWindowFunction(BuiltInWindowFunction),
/// A user defined aggregate function
AggregateUDF(Arc<AggregateUDF>),
/// A user defined aggregate function
WindowUDF(Arc<WindowUDF>),
}

/// Find DataFusion's built-in window function by name.
Expand Down Expand Up @@ -69,6 +74,7 @@ impl fmt::Display for WindowFunction {
WindowFunction::AggregateFunction(fun) => fun.fmt(f),
WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f),
WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f),
WindowFunction::WindowUDF(fun) => std::fmt::Debug::fmt(fun, f),
}
}
}
Expand Down Expand Up @@ -166,6 +172,9 @@ pub fn return_type(
WindowFunction::AggregateUDF(fun) => {
Ok((*(fun.return_type)(input_expr_types)?).clone())
}
WindowFunction::WindowUDF(fun) => {
Ok((*(fun.return_type)(input_expr_types)?).clone())
}
}
}

Expand Down Expand Up @@ -202,6 +211,7 @@ pub fn signature(fun: &WindowFunction) -> Signature {
WindowFunction::AggregateFunction(fun) => aggregate_function::signature(fun),
WindowFunction::BuiltInWindowFunction(fun) => signature_for_built_in(fun),
WindowFunction::AggregateUDF(fun) => fun.signature.clone(),
WindowFunction::WindowUDF(fun) => fun.signature.clone(),
}
}

Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ mod window_frame_state;
pub use aggregate::PlainAggregateWindowExpr;
pub use built_in::BuiltInWindowExpr;
pub use built_in_window_function_expr::BuiltInWindowFunctionExpr;
pub use partition_evaluator::PartitionEvaluator;
pub use sliding_aggregate::SlidingAggregateWindowExpr;
pub use window_expr::PartitionBatchState;
pub use window_expr::PartitionBatches;
Expand Down

0 comments on commit 53064e1

Please sign in to comment.