Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UDAF and UDWF support aliases #9489

Merged
merged 3 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2126,10 +2126,16 @@ impl FunctionRegistry for SessionState {
&mut self,
udaf: Arc<AggregateUDF>,
) -> Result<Option<Arc<AggregateUDF>>> {
udaf.aliases().iter().for_each(|alias| {
self.aggregate_functions.insert(alias.clone(), udaf.clone());
});
Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
}

fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
udwf.aliases().iter().for_each(|alias| {
self.window_functions.insert(alias.clone(), udwf.clone());
});
Ok(self.window_functions.insert(udwf.name().into(), udwf))
}

Expand All @@ -2144,11 +2150,23 @@ impl FunctionRegistry for SessionState {
}

fn deregister_udaf(&mut self, name: &str) -> Result<Option<Arc<AggregateUDF>>> {
Ok(self.aggregate_functions.remove(name))
let udaf = self.aggregate_functions.remove(name);
if let Some(udaf) = &udaf {
for alias in udaf.aliases() {
self.aggregate_functions.remove(alias);
}
}
Ok(udaf)
}

fn deregister_udwf(&mut self, name: &str) -> Result<Option<Arc<WindowUDF>>> {
Ok(self.window_functions.remove(name))
let udwf = self.window_functions.remove(name);
if let Some(udwf) = &udwf {
for alias in udwf.aliases() {
self.window_functions.remove(alias);
}
}
Ok(udwf)
}
}

Expand Down
37 changes: 37 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use std::sync::{
};

use datafusion::datasource::MemTable;
use datafusion::test_util::plan_and_collect;
use datafusion::{
arrow::{
array::{ArrayRef, Float64Array, TimestampNanosecondArray},
Expand Down Expand Up @@ -320,6 +321,42 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_user_defined_functions_with_alias() -> Result<()> {
let ctx = SessionContext::new();
let arr = Int32Array::from(vec![1]);
let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
ctx.register_batch("t", batch).unwrap();

let my_avg = create_udaf(
"dummy",
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
)
.with_aliases(vec!["dummy_alias"]);

ctx.register_udaf(my_avg);

let expected = [
"+------------+",
"| dummy(t.i) |",
"+------------+",
"| 1.0 |",
"+------------+",
];

let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?;
assert_batches_eq!(expected, &result);

let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?;
assert_batches_eq!(expected, &alias_result);

Ok(())
}

#[tokio::test]
async fn test_groups_accumulator() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \
odd_counter(val) OVER (PARTITION BY x ORDER BY y) \
from t ORDER BY x, y";

const UNBOUNDED_WINDOW_QUERY_WITH_ALIAS: &str = "SELECT x, y, val, \
odd_counter_alias(val) OVER (PARTITION BY x ORDER BY y) \
from t ORDER BY x, y";

/// A query with a window function evaluated over a moving window
const BOUNDED_WINDOW_QUERY: &str =
"SELECT x, y, val, \
Expand Down Expand Up @@ -118,6 +122,35 @@ async fn test_deregister_udwf() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_udwf_with_alias() {
let test_state = TestState::new();
let TestContext { ctx, .. } = TestContext::new(test_state);

let expected = vec![
"+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+",
"| x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW |",
"+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+",
"| 1 | a | 0 | 1 |",
"| 1 | b | 1 | 1 |",
"| 1 | c | 2 | 1 |",
"| 2 | d | 3 | 2 |",
"| 2 | e | 4 | 2 |",
"| 2 | f | 5 | 2 |",
"| 2 | g | 6 | 2 |",
"| 2 | h | 6 | 2 |",
"| 2 | i | 6 | 2 |",
"| 2 | j | 6 | 2 |",
"+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+",
];
assert_batches_eq!(
expected,
&execute(&ctx, UNBOUNDED_WINDOW_QUERY_WITH_ALIAS)
.await
.unwrap()
);
}

/// Basic user defined window function with bounded window
#[tokio::test]
async fn test_udwf_bounded_window_ignores_frame() {
Expand Down Expand Up @@ -491,6 +524,7 @@ impl OddCounter {
signature: Signature,
return_type: DataType,
test_state: Arc<TestState>,
aliases: Vec<String>,
}

impl SimpleWindowUDF {
Expand All @@ -502,6 +536,7 @@ impl OddCounter {
signature,
return_type,
test_state,
aliases: vec!["odd_counter_alias".to_string()],
}
}
}
Expand All @@ -526,6 +561,10 @@ impl OddCounter {
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state))))
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state)))
Expand Down
6 changes: 6 additions & 0 deletions datafusion/execution/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,15 @@ impl FunctionRegistry for TaskContext {
&mut self,
udaf: Arc<AggregateUDF>,
) -> Result<Option<Arc<AggregateUDF>>> {
udaf.aliases().iter().for_each(|alias| {
self.aggregate_functions.insert(alias.clone(), udaf.clone());
});
Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
}
fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
udwf.aliases().iter().for_each(|alias| {
self.window_functions.insert(alias.clone(), udwf.clone());
});
Ok(self.window_functions.insert(udwf.name().into(), udwf))
}
fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
Expand Down
71 changes: 71 additions & 0 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ impl AggregateUDF {
self.inner.clone()
}

/// Adds additional names that can be used to invoke this function, in
/// addition to `name`
///
/// If you implement [`AggregateUDFImpl`] directly you should return aliases directly.
pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
Self::new_from_impl(AliasedAggregateUDFImpl::new(self.inner.clone(), aliases))
}

/// creates an [`Expr`] that calls the aggregate function.
///
/// This utility allows using the UDAF without requiring access to
Expand All @@ -139,6 +147,11 @@ impl AggregateUDF {
self.inner.name()
}

/// Returns the aliases for this function.
pub fn aliases(&self) -> &[String] {
self.inner.aliases()
}

/// Returns this function's signature (what input types are accepted)
///
/// See [`AggregateUDFImpl::signature`] for more details.
Expand Down Expand Up @@ -277,6 +290,64 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
}

/// Returns any aliases (alternate names) for this function.
///
/// Note: `aliases` should only include names other than [`Self::name`].
/// Defaults to `[]` (no aliases)
fn aliases(&self) -> &[String] {
&[]
}
}

/// AggregateUDF that adds an alias to the underlying function. It is better to
/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible.
#[derive(Debug)]
struct AliasedAggregateUDFImpl {
inner: Arc<dyn AggregateUDFImpl>,
aliases: Vec<String>,
}

impl AliasedAggregateUDFImpl {
pub fn new(
inner: Arc<dyn AggregateUDFImpl>,
new_aliases: impl IntoIterator<Item = &'static str>,
) -> Self {
let mut aliases = inner.aliases().to_vec();
aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));

Self { inner, aliases }
}
}

impl AggregateUDFImpl for AliasedAggregateUDFImpl {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
self.inner.name()
}

fn signature(&self) -> &Signature {
self.inner.signature()
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}

fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>> {
self.inner.accumulator(arg)
}

fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
self.inner.state_type(return_type)
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers
Expand Down
71 changes: 69 additions & 2 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl WindowUDF {
///
/// See [`WindowUDFImpl`] for a more convenient way to create a
/// `WindowUDF` using trait objects
#[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl instead")]
#[deprecated(since = "34.0.0", note = "please implement WindowUDFImpl instead")]
pub fn new(
name: &str,
signature: &Signature,
Expand Down Expand Up @@ -112,6 +112,14 @@ impl WindowUDF {
self.inner.clone()
}

/// Adds additional names that can be used to invoke this function, in
/// addition to `name`
///
/// If you implement [`WindowUDFImpl`] directly you should return aliases directly.
pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
Self::new_from_impl(AliasedWindowUDFImpl::new(self.inner.clone(), aliases))
}

/// creates a [`Expr`] that calls the window function given
/// the `partition_by`, `order_by`, and `window_frame` definition
///
Expand Down Expand Up @@ -143,6 +151,11 @@ impl WindowUDF {
self.inner.name()
}

/// Returns the aliases for this function.
pub fn aliases(&self) -> &[String] {
self.inner.aliases()
}

/// Returns this function's signature (what input types are accepted)
///
/// See [`WindowUDFImpl::signature`] for more details.
Expand Down Expand Up @@ -217,7 +230,7 @@ where
/// fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> { unimplemented!() }
/// }
///
/// // Create a new ScalarUDF from the implementation
/// // Create a new WindowUDF from the implementation
/// let smooth_it = WindowUDF::from(SmoothIt::new());
///
/// // Call the function `add_one(col)`
Expand Down Expand Up @@ -245,6 +258,60 @@ pub trait WindowUDFImpl: Debug + Send + Sync {

/// Invoke the function, returning the [`PartitionEvaluator`] instance
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>>;

/// Returns any aliases (alternate names) for this function.
///
/// Note: `aliases` should only include names other than [`Self::name`].
/// Defaults to `[]` (no aliases)
fn aliases(&self) -> &[String] {
&[]
}
}

/// WindowUDF that adds an alias to the underlying function. It is better to
/// implement [`WindowUDFImpl`], which supports aliases, directly if possible.
#[derive(Debug)]
struct AliasedWindowUDFImpl {
inner: Arc<dyn WindowUDFImpl>,
aliases: Vec<String>,
}

impl AliasedWindowUDFImpl {
pub fn new(
inner: Arc<dyn WindowUDFImpl>,
new_aliases: impl IntoIterator<Item = &'static str>,
) -> Self {
let mut aliases = inner.aliases().to_vec();
aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));

Self { inner, aliases }
}
}

impl WindowUDFImpl for AliasedWindowUDFImpl {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
self.inner.name()
}

fn signature(&self) -> &Signature {
self.inner.signature()
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}

fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
self.inner.partition_evaluator()
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

/// Implementation of [`WindowUDFImpl`] that wraps the function style pointers
Expand Down
Loading