From 1f151ea1c46d65b74595e7bc33c5cecf9072a285 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Thu, 7 Mar 2024 16:40:51 +0800 Subject: [PATCH 1/3] UDAF and UDWF support aliases --- datafusion/core/src/execution/context/mod.rs | 22 ++++++++++++++++++-- datafusion/execution/src/task.rs | 6 ++++++ datafusion/expr/src/udaf.rs | 13 ++++++++++++ datafusion/expr/src/udwf.rs | 13 ++++++++++++ 4 files changed, 52 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index e071c5c80e11..9bca00ffb573 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2126,10 +2126,16 @@ impl FunctionRegistry for SessionState { &mut self, udaf: Arc, ) -> Result>> { + udaf.aliases().iter().for_each(|alias| { + self.aggregate_functions.insert(alias.clone(), udaf.clone()); + }); Ok(self.aggregate_functions.insert(udaf.name().into(), udaf)) } fn register_udwf(&mut self, udwf: Arc) -> Result>> { + udwf.aliases().iter().for_each(|alias| { + self.window_functions.insert(alias.clone(), udwf.clone()); + }); Ok(self.window_functions.insert(udwf.name().into(), udwf)) } @@ -2144,11 +2150,23 @@ impl FunctionRegistry for SessionState { } fn deregister_udaf(&mut self, name: &str) -> Result>> { - Ok(self.aggregate_functions.remove(name)) + let udaf = self.aggregate_functions.remove(name); + if let Some(udaf) = &udaf { + for alias in udaf.aliases() { + self.aggregate_functions.remove(alias); + } + } + Ok(udaf) } fn deregister_udwf(&mut self, name: &str) -> Result>> { - Ok(self.window_functions.remove(name)) + let udwf = self.window_functions.remove(name); + if let Some(udwf) = &udwf { + for alias in udwf.aliases() { + self.window_functions.remove(alias); + } + } + Ok(udwf) } } diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index b39b4a00327b..cae410655d10 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -207,9 +207,15 @@ impl FunctionRegistry for TaskContext { &mut self, udaf: Arc, ) -> Result>> { + udaf.aliases().iter().for_each(|alias| { + self.aggregate_functions.insert(alias.clone(), udaf.clone()); + }); Ok(self.aggregate_functions.insert(udaf.name().into(), udaf)) } fn register_udwf(&mut self, udwf: Arc) -> Result>> { + udwf.aliases().iter().for_each(|alias| { + self.window_functions.insert(alias.clone(), udwf.clone()); + }); Ok(self.window_functions.insert(udwf.name().into(), udwf)) } fn register_udf(&mut self, udf: Arc) -> Result>> { diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index e56723063e41..63edb083f331 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -139,6 +139,11 @@ impl AggregateUDF { self.inner.name() } + /// Returns the aliases for this function. + pub fn aliases(&self) -> &[String] { + self.inner.aliases() + } + /// Returns this function's signature (what input types are accepted) /// /// See [`AggregateUDFImpl::signature`] for more details. @@ -277,6 +282,14 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn create_groups_accumulator(&self) -> Result> { not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") } + + /// Returns any aliases (alternate names) for this function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. + /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] + } } /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 3ab40fe70a91..3442d99a97fe 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -143,6 +143,11 @@ impl WindowUDF { self.inner.name() } + /// Returns the aliases for this function. + pub fn aliases(&self) -> &[String] { + self.inner.aliases() + } + /// Returns this function's signature (what input types are accepted) /// /// See [`WindowUDFImpl::signature`] for more details. @@ -245,6 +250,14 @@ pub trait WindowUDFImpl: Debug + Send + Sync { /// Invoke the function, returning the [`PartitionEvaluator`] instance fn partition_evaluator(&self) -> Result>; + + /// Returns any aliases (alternate names) for this function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. + /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] + } } /// Implementation of [`WindowUDFImpl`] that wraps the function style pointers From a68c7b133bff5641a917a51678b8ec05aa618a79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Mon, 11 Mar 2024 11:01:07 +0800 Subject: [PATCH 2/3] Add tests for udaf and udwf aliases --- .../user_defined/user_defined_aggregates.rs | 37 ++++++++++++ .../user_defined_window_functions.rs | 39 +++++++++++++ datafusion/expr/src/udaf.rs | 58 +++++++++++++++++++ datafusion/expr/src/udwf.rs | 58 ++++++++++++++++++- 4 files changed, 190 insertions(+), 2 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 9e231d25f298..3f40c55a3ed7 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -27,6 +27,7 @@ use std::sync::{ }; use datafusion::datasource::MemTable; +use datafusion::test_util::plan_and_collect; use datafusion::{ arrow::{ array::{ArrayRef, Float64Array, TimestampNanosecondArray}, @@ -320,6 +321,42 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_user_defined_functions_with_alias() -> Result<()> { + let ctx = SessionContext::new(); + let arr = Int32Array::from(vec![1]); + let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; + ctx.register_batch("t", batch).unwrap(); + + let my_avg = create_udaf( + "dummy", + vec![DataType::Float64], + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(|_| Ok(Box::::default())), + Arc::new(vec![DataType::UInt64, DataType::Float64]), + ) + .with_aliases(vec!["dummy_alias"]); + + ctx.register_udaf(my_avg); + + let expected = [ + "+------------+", + "| dummy(t.i) |", + "+------------+", + "| 1.0 |", + "+------------+", + ]; + + let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?; + assert_batches_eq!(expected, &result); + + let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?; + assert_batches_eq!(expected, &alias_result); + + Ok(()) +} + #[tokio::test] async fn test_groups_accumulator() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index cfd74f8861e3..f4865a394059 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -41,6 +41,10 @@ const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \ odd_counter(val) OVER (PARTITION BY x ORDER BY y) \ from t ORDER BY x, y"; +const UNBOUNDED_WINDOW_QUERY_WITH_ALIAS: &str = "SELECT x, y, val, \ + odd_counter_alias(val) OVER (PARTITION BY x ORDER BY y) \ + from t ORDER BY x, y"; + /// A query with a window function evaluated over a moving window const BOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \ @@ -118,6 +122,35 @@ async fn test_deregister_udwf() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_udwf_with_alias() { + let test_state = TestState::new(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", + "| x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW |", + "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", + "| 1 | a | 0 | 1 |", + "| 1 | b | 1 | 1 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 2 |", + "| 2 | e | 4 | 2 |", + "| 2 | f | 5 | 2 |", + "| 2 | g | 6 | 2 |", + "| 2 | h | 6 | 2 |", + "| 2 | i | 6 | 2 |", + "| 2 | j | 6 | 2 |", + "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, UNBOUNDED_WINDOW_QUERY_WITH_ALIAS) + .await + .unwrap() + ); +} + /// Basic user defined window function with bounded window #[tokio::test] async fn test_udwf_bounded_window_ignores_frame() { @@ -491,6 +524,7 @@ impl OddCounter { signature: Signature, return_type: DataType, test_state: Arc, + aliases: Vec, } impl SimpleWindowUDF { @@ -502,6 +536,7 @@ impl OddCounter { signature, return_type, test_state, + aliases: vec!["odd_counter_alias".to_string()], } } } @@ -526,6 +561,10 @@ impl OddCounter { fn partition_evaluator(&self) -> Result> { Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state)))) } + + fn aliases(&self) -> &[String] { + &self.aliases + } } ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state))) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 63edb083f331..c46dd9cd3a6f 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -118,6 +118,14 @@ impl AggregateUDF { self.inner.clone() } + /// Adds additional names that can be used to invoke this function, in + /// addition to `name` + /// + /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly. + pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { + Self::new_from_impl(AliasedAggregateUDFImpl::new(self.inner.clone(), aliases)) + } + /// creates an [`Expr`] that calls the aggregate function. /// /// This utility allows using the UDAF without requiring access to @@ -292,6 +300,56 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { } } +/// AggregateUDF that adds an alias to the underlying function. It is better to +/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible. +#[derive(Debug)] +struct AliasedAggregateUDFImpl { + inner: Arc, + aliases: Vec, +} + +impl AliasedAggregateUDFImpl { + pub fn new( + inner: Arc, + new_aliases: impl IntoIterator, + ) -> Self { + let mut aliases = inner.aliases().to_vec(); + aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); + + Self { inner, aliases } + } +} + +impl AggregateUDFImpl for AliasedAggregateUDFImpl { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.inner.name() + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner.return_type(arg_types) + } + + fn accumulator(&self, arg: &DataType) -> Result> { + self.inner.accumulator(arg) + } + + fn state_type(&self, return_type: &DataType) -> Result> { + self.inner.state_type(return_type) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers /// of the older API pub struct AggregateUDFLegacyWrapper { diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 3442d99a97fe..d3925f2e1925 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -80,7 +80,7 @@ impl WindowUDF { /// /// See [`WindowUDFImpl`] for a more convenient way to create a /// `WindowUDF` using trait objects - #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl instead")] + #[deprecated(since = "34.0.0", note = "please implement WindowUDFImpl instead")] pub fn new( name: &str, signature: &Signature, @@ -112,6 +112,14 @@ impl WindowUDF { self.inner.clone() } + /// Adds additional names that can be used to invoke this function, in + /// addition to `name` + /// + /// If you implement [`WindowUDFImpl`] directly you should return aliases directly. + pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { + Self::new_from_impl(AliasedWindowUDFImpl::new(self.inner.clone(), aliases)) + } + /// creates a [`Expr`] that calls the window function given /// the `partition_by`, `order_by`, and `window_frame` definition /// @@ -222,7 +230,7 @@ where /// fn partition_evaluator(&self) -> Result> { unimplemented!() } /// } /// -/// // Create a new ScalarUDF from the implementation +/// // Create a new WindowUDF from the implementation /// let smooth_it = WindowUDF::from(SmoothIt::new()); /// /// // Call the function `add_one(col)` @@ -260,6 +268,52 @@ pub trait WindowUDFImpl: Debug + Send + Sync { } } +/// WindowUDF that adds an alias to the underlying function. It is better to +/// implement [`WindowUDFImpl`], which supports aliases, directly if possible. +#[derive(Debug)] +struct AliasedWindowUDFImpl { + inner: Arc, + aliases: Vec, +} + +impl AliasedWindowUDFImpl { + pub fn new( + inner: Arc, + new_aliases: impl IntoIterator, + ) -> Self { + let mut aliases = inner.aliases().to_vec(); + aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); + + Self { inner, aliases } + } +} + +impl WindowUDFImpl for AliasedWindowUDFImpl { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.inner.name() + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner.return_type(arg_types) + } + + fn partition_evaluator(&self) -> Result> { + self.inner.partition_evaluator() + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + /// Implementation of [`WindowUDFImpl`] that wraps the function style pointers /// of the older API (see /// for more details) From f3ec4fcf9c056c872a46029544d648424ea20168 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Mon, 11 Mar 2024 11:22:28 +0800 Subject: [PATCH 3/3] Fix clippy lint --- .../core/tests/user_defined/user_defined_window_functions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index f4865a394059..3c607301fc98 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -125,7 +125,7 @@ async fn test_deregister_udwf() -> Result<()> { #[tokio::test] async fn test_udwf_with_alias() { let test_state = TestState::new(); - let TestContext { ctx, test_state } = TestContext::new(test_state); + let TestContext { ctx, .. } = TestContext::new(test_state); let expected = vec![ "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+",