diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index bec672ab6f6c..0e12bf9da21a 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -19,6 +19,7 @@ use crate::function_err::generate_signature_error_msg; use crate::nullif::SUPPORTED_NULLIF_TYPES; +use crate::partition_evaluator::PartitionEvaluator; use crate::type_coercion::functions::data_types; use crate::ColumnarValue; use crate::{ @@ -54,6 +55,12 @@ pub type AccumulatorFunctionImplementation = pub type StateTypeFunction = Arc Result>> + Send + Sync>; +/// Factory that creates a PartitionEvaluator for the given aggregate, given +/// its return datatype. +pub type PartitionEvaluatorFunctionFactory = + Arc Result> + Send + Sync>; + + macro_rules! make_utf8_to_return_type { ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { fn $FUNC(arg_type: &DataType, name: &str) -> Result { diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index c5c364986606..8a556669709a 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -53,6 +53,8 @@ mod udwf; pub mod utils; pub mod window_frame; pub mod window_function; +pub mod partition_evaluator; +pub mod window_frame_state; pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs similarity index 91% rename from datafusion/physical-expr/src/window/partition_evaluator.rs rename to datafusion/expr/src/partition_evaluator.rs index db60fdd5f1fa..316b4a5d58be 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/expr/src/partition_evaluator.rs @@ -17,14 +17,25 @@ //! Partition evaluation module -use crate::window::window_expr::BuiltinWindowState; -use crate::window::WindowAggState; +use crate::window_frame_state::WindowAggState; use arrow::array::ArrayRef; use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; +use std::any::Any; use std::fmt::Debug; use std::ops::Range; + +/// Trait for the state managed by this partition evaluator +/// +/// This follows the existing pattern, but maybe we can improve it :thinking: + +pub trait PartitionState { + /// Returns the aggregate expression as [`Any`](std::any::Any) so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; +} + /// Partition evaluator for Window Functions /// /// # Background @@ -100,12 +111,9 @@ pub trait PartitionEvaluator: Debug + Send { false } - /// Returns the internal state of the window function - /// - /// Only used for stateful evaluation - fn state(&self) -> Result { - // If we do not use state we just return Default - Ok(BuiltinWindowState::Default) + /// Returns the internal state of the window function, if any + fn state(&self) -> Result>> { + Ok(None) } /// Updates the internal state for window function @@ -130,7 +138,7 @@ pub trait PartitionEvaluator: Debug + Send { /// Sets the internal state for window function /// /// Only used for stateful evaluation - fn set_state(&mut self, _state: &BuiltinWindowState) -> Result<()> { + fn set_state(&mut self, state: Box) -> Result<()> { Err(DataFusionError::NotImplemented( "set_state is not implemented for this window function".to_string(), )) diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/expr/src/window_frame_state.rs similarity index 89% rename from datafusion/physical-expr/src/window/window_frame_state.rs rename to datafusion/expr/src/window_frame_state.rs index e23a58a09b66..3100f3a58f1e 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/expr/src/window_frame_state.rs @@ -19,16 +19,97 @@ //! depending on the window frame mode: RANGE, ROWS, GROUPS. use arrow::array::ArrayRef; +use arrow::compute::{concat}; use arrow::compute::kernels::sort::SortOptions; +use arrow::record_batch::RecordBatch; use datafusion_common::utils::{compare_rows, get_row_at_idx, search_in_slice}; use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; +use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::cmp::min; use std::collections::VecDeque; use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; + +/// State for each unique partition determined according to PARTITION BY column(s) +#[derive(Debug)] +pub struct PartitionBatchState { + /// The record_batch belonging to current partition + pub record_batch: RecordBatch, + /// Flag indicating whether we have received all data for this partition + pub is_end: bool, + /// Number of rows emitted for each partition + pub n_out_row: usize, +} + + +#[derive(Debug)] +pub struct WindowAggState { + /// The range that we calculate the window function + pub window_frame_range: Range, + pub window_frame_ctx: Option, + /// The index of the last row that its result is calculated inside the partition record batch buffer. + pub last_calculated_index: usize, + /// The offset of the deleted row number + pub offset_pruned_rows: usize, + /// Stores the results calculated by window frame + pub out_col: ArrayRef, + /// Keeps track of how many rows should be generated to be in sync with input record_batch. + // (For each row in the input record batch we need to generate a window result). + pub n_row_result_missing: usize, + /// flag indicating whether we have received all data for this partition + pub is_end: bool, +} + +impl WindowAggState { + pub fn prune_state(&mut self, n_prune: usize) { + self.window_frame_range = Range { + start: self.window_frame_range.start - n_prune, + end: self.window_frame_range.end - n_prune, + }; + self.last_calculated_index -= n_prune; + self.offset_pruned_rows += n_prune; + + match self.window_frame_ctx.as_mut() { + // Rows have no state do nothing + Some(WindowFrameContext::Rows(_)) => {} + Some(WindowFrameContext::Range { .. }) => {} + Some(WindowFrameContext::Groups { state, .. }) => { + let mut n_group_to_del = 0; + for (_, end_idx) in &state.group_end_indices { + if n_prune < *end_idx { + break; + } + n_group_to_del += 1; + } + state.group_end_indices.drain(0..n_group_to_del); + state + .group_end_indices + .iter_mut() + .for_each(|(_, start_idx)| *start_idx -= n_prune); + state.current_group_idx -= n_group_to_del; + } + None => {} + }; + } +} + +impl WindowAggState { + pub fn update( + &mut self, + out_col: &ArrayRef, + partition_batch_state: &PartitionBatchState, + ) -> Result<()> { + self.last_calculated_index += out_col.len(); + self.out_col = concat(&[&self.out_col, &out_col])?; + self.n_row_result_missing = + partition_batch_state.record_batch.num_rows() - self.last_calculated_index; + self.is_end = partition_batch_state.is_end; + Ok(()) + } +} + /// This object stores the window frame state for use in incremental calculations. #[derive(Debug)] pub enum WindowFrameContext { @@ -547,11 +628,10 @@ fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result, - pub window_frame_ctx: Option, - /// The index of the last row that its result is calculated inside the partition record batch buffer. - pub last_calculated_index: usize, - /// The offset of the deleted row number - pub offset_pruned_rows: usize, - /// Stores the results calculated by window frame - pub out_col: ArrayRef, - /// Keeps track of how many rows should be generated to be in sync with input record_batch. - // (For each row in the input record batch we need to generate a window result). - pub n_row_result_missing: usize, - /// flag indicating whether we have received all data for this partition - pub is_end: bool, -} - -impl WindowAggState { - pub fn prune_state(&mut self, n_prune: usize) { - self.window_frame_range = Range { - start: self.window_frame_range.start - n_prune, - end: self.window_frame_range.end - n_prune, - }; - self.last_calculated_index -= n_prune; - self.offset_pruned_rows += n_prune; - - match self.window_frame_ctx.as_mut() { - // Rows have no state do nothing - Some(WindowFrameContext::Rows(_)) => {} - Some(WindowFrameContext::Range { .. }) => {} - Some(WindowFrameContext::Groups { state, .. }) => { - let mut n_group_to_del = 0; - for (_, end_idx) in &state.group_end_indices { - if n_prune < *end_idx { - break; - } - n_group_to_del += 1; - } - state.group_end_indices.drain(0..n_group_to_del); - state - .group_end_indices - .iter_mut() - .for_each(|(_, start_idx)| *start_idx -= n_prune); - state.current_group_idx -= n_group_to_del; - } - None => {} - }; - } -} - -impl WindowAggState { - pub fn update( - &mut self, - out_col: &ArrayRef, - partition_batch_state: &PartitionBatchState, - ) -> Result<()> { - self.last_calculated_index += out_col.len(); - self.out_col = concat(&[&self.out_col, &out_col])?; - self.n_row_result_missing = - partition_batch_state.record_batch.num_rows() - self.last_calculated_index; - self.is_end = partition_batch_state.is_end; - Ok(()) - } -} - -/// State for each unique partition determined according to PARTITION BY column(s) -#[derive(Debug)] -pub struct PartitionBatchState { - /// The record_batch belonging to current partition - pub record_batch: RecordBatch, - /// Flag indicating whether we have received all data for this partition - pub is_end: bool, - /// Number of rows emitted for each partition - pub n_out_row: usize, -} - /// Key for IndexMap for each unique partition /// /// For instance, if window frame is `OVER(PARTITION BY a,b)`,