diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index b6269a560386..d0ee38ac90fd 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -1594,7 +1594,7 @@ pub fn create_window_expr_with_name( }) .collect::>>()?; if !is_window_valid(window_frame) { - return Err(DataFusionError::Execution(format!( + return Err(DataFusionError::Plan(format!( "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", window_frame.start_bound, window_frame.end_bound ))); diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 3fc5b956966f..5476652ade07 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -982,19 +982,20 @@ async fn window_frame_groups_multiple_order_columns() -> Result<()> { async fn window_frame_groups_without_order_by() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; - // execute the query - let df = ctx + // Try executing an erroneous query (the ORDER BY clause is missing in the + // window frame): + let err = ctx .sql( "SELECT SUM(c4) OVER(PARTITION BY c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM aggregate_test_100 ORDER BY c9;", ) - .await?; - let err = df.collect().await.unwrap_err(); + .await + .unwrap_err(); assert_contains!( err.to_string(), - "Execution error: GROUPS mode requires an ORDER BY clause".to_owned() + "Error during planning: GROUPS mode requires an ORDER BY clause".to_owned() ); Ok(()) } @@ -1034,7 +1035,7 @@ async fn window_frame_creation() -> Result<()> { let results = df.collect().await; assert_eq!( results.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)" + "Error during planning: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)" ); let df = ctx @@ -1047,7 +1048,20 @@ async fn window_frame_creation() -> Result<()> { let results = df.collect().await; assert_eq!( results.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound (2 FOLLOWING) cannot be larger than end bound (1 FOLLOWING)" + "Error during planning: Invalid window frame: start bound (2 FOLLOWING) cannot be larger than end bound (1 FOLLOWING)" + ); + + let err = ctx + .sql( + "SELECT + COUNT(c1) OVER(GROUPS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) + FROM aggregate_test_100;", + ) + .await + .unwrap_err(); + assert_contains!( + err.to_string(), + "Error during planning: GROUPS mode requires an ORDER BY clause" ); Ok(()) @@ -1123,6 +1137,39 @@ async fn test_window_row_number_aggregate() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_window_range_equivalent_frames() -> Result<()> { + let config = SessionConfig::new(); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + COUNT(*) OVER(ORDER BY c9, c1 RANGE BETWEEN CURRENT ROW AND CURRENT ROW) AS cnt1, + COUNT(*) OVER(ORDER BY c9, c1 RANGE UNBOUNDED PRECEDING) AS cnt2, + COUNT(*) OVER(ORDER BY c9, c1 RANGE CURRENT ROW) AS cnt3, + COUNT(*) OVER(RANGE BETWEEN CURRENT ROW AND CURRENT ROW) AS cnt4, + COUNT(*) OVER(RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cnt5, + COUNT(*) OVER(RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS cnt6 + FROM aggregate_test_100 + ORDER BY c9 + LIMIT 5"; + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------+------+------+------+------+------+------+", + "| c9 | cnt1 | cnt2 | cnt3 | cnt4 | cnt5 | cnt6 |", + "+-----------+------+------+------+------+------+------+", + "| 28774375 | 1 | 1 | 1 | 100 | 100 | 100 |", + "| 63044568 | 1 | 2 | 1 | 100 | 100 | 100 |", + "| 141047417 | 1 | 3 | 1 | 100 | 100 | 100 |", + "| 141680161 | 1 | 4 | 1 | 100 | 100 | 100 |", + "| 145294611 | 1 | 5 | 1 | 100 | 100 | 100 |", + "+-----------+------+------+------+------+------+------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn test_window_cume_dist() -> Result<()> { let config = SessionConfig::new(); diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index bf74d02b7005..c25d2491e45a 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -144,6 +144,36 @@ impl WindowFrame { } } +/// Construct equivalent explicit window frames for implicit corner cases. +/// With this processing, we may assume in downstream code that RANGE/GROUPS +/// frames contain an appropriate ORDER BY clause. +pub fn regularize(mut frame: WindowFrame, order_bys: usize) -> Result { + if frame.units == WindowFrameUnits::Range && order_bys != 1 { + // Normally, RANGE frames require an ORDER BY clause with exactly one + // column. However, an ORDER BY clause may be absent in two edge cases. + if (frame.start_bound.is_unbounded() + || frame.start_bound == WindowFrameBound::CurrentRow) + && (frame.end_bound == WindowFrameBound::CurrentRow + || frame.end_bound.is_unbounded()) + { + if order_bys == 0 { + frame.units = WindowFrameUnits::Rows; + frame.start_bound = + WindowFrameBound::Preceding(ScalarValue::UInt64(None)); + frame.end_bound = WindowFrameBound::Following(ScalarValue::UInt64(None)); + } + } else { + return Err(DataFusionError::Plan(format!( + "With window frame of type RANGE, the ORDER BY expression must be of length 1, got {}", order_bys))); + } + } else if frame.units == WindowFrameUnits::Groups && order_bys == 0 { + return Err(DataFusionError::Plan( + "GROUPS mode requires an ORDER BY clause".to_string(), + )); + }; + Ok(frame) +} + /// There are five ways to describe starting and ending frame boundaries: /// /// 1. UNBOUNDED PRECEDING diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index b53164f66944..70ddb2c7671a 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -104,17 +104,15 @@ impl WindowExpr for BuiltInWindowExpr { let mut row_wise_results = vec![]; let (values, order_bys) = self.get_values_orderbys(batch)?; - let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); - let range = Range { start: 0, end: 0 }; + let mut window_frame_ctx = WindowFrameContext::new( + &self.window_frame, + sort_options, + Range { start: 0, end: 0 }, + ); // We iterate on each row to calculate window frame range and and window function result for idx in 0..num_rows { - let range = window_frame_ctx.calculate_range( - &order_bys, - &sort_options, - num_rows, - idx, - &range, - )?; + let range = + window_frame_ctx.calculate_range(&order_bys, num_rows, idx)?; let value = evaluator.evaluate_inside_range(&values, &range)?; row_wise_results.push(value); } @@ -168,7 +166,13 @@ impl WindowExpr for BuiltInWindowExpr { // We iterate on each row to perform a running calculation. let record_batch = &partition_batch_state.record_batch; let num_rows = record_batch.num_rows(); - let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + let last_range = state.window_frame_range.clone(); + let mut window_frame_ctx = WindowFrameContext::new( + &self.window_frame, + sort_options.clone(), + // Start search from the last range + last_range, + ); let sort_partition_points = if evaluator.include_rank() { let columns = self.sort_columns(record_batch)?; self.evaluate_partition_points(num_rows, &columns)? @@ -179,13 +183,7 @@ impl WindowExpr for BuiltInWindowExpr { let mut last_range = state.window_frame_range.clone(); for idx in state.last_calculated_index..num_rows { state.window_frame_range = if self.expr.uses_window_frame() { - window_frame_ctx.calculate_range( - &order_bys, - &sort_options, - num_rows, - idx, - &state.window_frame_range, - ) + window_frame_ctx.calculate_range(&order_bys, num_rows, idx) } else { evaluator.get_range(state, num_rows) }?; diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 065d26fef00e..96e22976b3d8 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -162,18 +162,10 @@ pub trait AggregateWindowExpr: WindowExpr { /// Evaluates the window function against the batch. fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result { - let mut window_frame_ctx = WindowFrameContext::new(self.get_window_frame()); let mut accumulator = self.get_accumulator()?; let mut last_range = Range { start: 0, end: 0 }; let mut idx = 0; - self.get_result_column( - &mut accumulator, - batch, - &mut window_frame_ctx, - &mut last_range, - &mut idx, - false, - ) + self.get_result_column(&mut accumulator, batch, &mut last_range, &mut idx, false) } /// Statefully evaluates the window function against the batch. Maintains @@ -207,11 +199,9 @@ pub trait AggregateWindowExpr: WindowExpr { let mut state = &mut window_state.state; let record_batch = &partition_batch_state.record_batch; - let mut window_frame_ctx = WindowFrameContext::new(self.get_window_frame()); let out_col = self.get_result_column( accumulator, record_batch, - &mut window_frame_ctx, &mut state.window_frame_range, &mut state.last_calculated_index, !partition_batch_state.is_end, @@ -230,7 +220,6 @@ pub trait AggregateWindowExpr: WindowExpr { &self, accumulator: &mut Box, record_batch: &RecordBatch, - window_frame_ctx: &mut WindowFrameContext, last_range: &mut Range, idx: &mut usize, not_end: bool, @@ -240,15 +229,15 @@ pub trait AggregateWindowExpr: WindowExpr { let length = values[0].len(); let sort_options: Vec = self.order_by().iter().map(|o| o.options).collect(); + let mut window_frame_ctx = WindowFrameContext::new( + self.get_window_frame(), + sort_options, + // Start search from the last range + last_range.clone(), + ); let mut row_wise_results: Vec = vec![]; while *idx < length { - let cur_range = window_frame_ctx.calculate_range( - &order_bys, - &sort_options, - length, - *idx, - last_range, - )?; + let cur_range = window_frame_ctx.calculate_range(&order_bys, length, *idx)?; // Exit if the range extends all the way: if cur_range.end == length && not_end { break; diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 9cde3cbdf4b5..64abacde49c1 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -15,14 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! This module provides utilities for window frame index calculations depending on the window frame mode: -//! RANGE, ROWS, GROUPS. +//! This module provides utilities for window frame index calculations +//! depending on the window frame mode: RANGE, ROWS, GROUPS. use arrow::array::ArrayRef; use arrow::compute::kernels::sort::SortOptions; -use datafusion_common::utils::{ - compare_rows, find_bisect_point, get_row_at_idx, search_in_slice, -}; +use datafusion_common::utils::{compare_rows, get_row_at_idx, search_in_slice}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::cmp::min; @@ -34,14 +32,18 @@ use std::sync::Arc; /// This object stores the window frame state for use in incremental calculations. #[derive(Debug)] pub enum WindowFrameContext<'a> { - // ROWS-frames are inherently stateless: + /// ROWS frames are inherently stateless. Rows(&'a Arc), - // RANGE-frames will soon have a stateful implementation that is more efficient than a stateless one: + /// RANGE frames are stateful, they store indices specifying where the + /// previous search left off. This amortizes the overall cost to O(n) + /// where n denotes the row count. Range { window_frame: &'a Arc, state: WindowFrameStateRange, }, - // GROUPS-frames have a stateful implementation that is more efficient than a stateless one: + /// GROUPS frames are stateful, they store group boundaries and indices + /// specifying where the previous search left off. This amortizes the + /// overall cost to O(n) where n denotes the row count. Groups { window_frame: &'a Arc, state: WindowFrameStateGroups, @@ -49,13 +51,17 @@ pub enum WindowFrameContext<'a> { } impl<'a> WindowFrameContext<'a> { - /// Create a new default state for the given window frame. - pub fn new(window_frame: &'a Arc) -> Self { + /// Create a new state object for the given window frame. + pub fn new( + window_frame: &'a Arc, + sort_options: Vec, + last_range: Range, + ) -> Self { match window_frame.units { WindowFrameUnits::Rows => WindowFrameContext::Rows(window_frame), WindowFrameUnits::Range => WindowFrameContext::Range { window_frame, - state: WindowFrameStateRange::default(), + state: WindowFrameStateRange::new(sort_options, last_range), }, WindowFrameUnits::Groups => WindowFrameContext::Groups { window_frame, @@ -68,30 +74,23 @@ impl<'a> WindowFrameContext<'a> { pub fn calculate_range( &mut self, range_columns: &[ArrayRef], - sort_options: &[SortOptions], length: usize, idx: usize, - last_range: &Range, ) -> Result> { match *self { WindowFrameContext::Rows(window_frame) => { Self::calculate_range_rows(window_frame, length, idx) } - // sort_options is used in RANGE mode calculations because the ordering and the position of the nulls - // have impact on the range calculations and comparison of the rows. + // Sort options is used in RANGE mode calculations because the + // ordering or position of NULLs impact range calculations and + // comparison of rows. WindowFrameContext::Range { window_frame, ref mut state, - } => state.calculate_range( - window_frame, - range_columns, - sort_options, - length, - idx, - last_range, - ), - // sort_options is not used in GROUPS mode calculations as the inequality of two rows is the indicator - // of a group change, and the ordering and the position of the nulls do not have impact on inequality. + } => state.calculate_range(window_frame, range_columns, length, idx), + // Sort options is not used in GROUPS mode calculations as the + // inequality of two rows indicates a group change, and ordering + // or position of NULLs do not impact inequality. WindowFrameContext::Groups { window_frame, ref mut state, @@ -159,22 +158,37 @@ impl<'a> WindowFrameContext<'a> { } } -/// This structure encapsulates all the state information we require as we -/// scan ranges of data while processing window frames. Currently we calculate -/// things from scratch every time, but we will make this incremental in the future. +/// This structure encapsulates all the state information we require as we scan +/// ranges of data while processing RANGE frames. Attribute `last_range` stores +/// the resulting indices from the previous search. Since the indices only +/// advance forward, we start from `last_range` subsequently. Thus, the overall +/// time complexity of linear search amortizes to O(n) where n denotes the total +/// row count. +/// Attribute `sort_options` stores the column ordering specified by the ORDER +/// BY clause. This information is used to calculate the range. #[derive(Debug, Default)] -pub struct WindowFrameStateRange {} +pub struct WindowFrameStateRange { + last_range: Range, + sort_options: Vec, +} impl WindowFrameStateRange { + /// Create a new object to store the search state. + fn new(sort_options: Vec, last_range: Range) -> Self { + Self { + // Stores the search range we calculate for future use. + last_range, + sort_options, + } + } + /// This function calculates beginning/ending indices for the frame of the current row. fn calculate_range( &mut self, window_frame: &Arc, range_columns: &[ArrayRef], - sort_options: &[SortOptions], length: usize, idx: usize, - last_range: &Range, ) -> Result> { let start = match window_frame.start_bound { WindowFrameBound::Preceding(ref n) => { @@ -184,35 +198,23 @@ impl WindowFrameStateRange { } else { self.calculate_index_of_row::( range_columns, - sort_options, idx, Some(n), - last_range, - length, - )? - } - } - WindowFrameBound::CurrentRow => { - if range_columns.is_empty() { - 0 - } else { - self.calculate_index_of_row::( - range_columns, - sort_options, - idx, - None, - last_range, length, )? } } + WindowFrameBound::CurrentRow => self.calculate_index_of_row::( + range_columns, + idx, + None, + length, + )?, WindowFrameBound::Following(ref n) => self .calculate_index_of_row::( range_columns, - sort_options, idx, Some(n), - last_range, length, )?, }; @@ -220,26 +222,16 @@ impl WindowFrameStateRange { WindowFrameBound::Preceding(ref n) => self .calculate_index_of_row::( range_columns, - sort_options, idx, Some(n), - last_range, length, )?, - WindowFrameBound::CurrentRow => { - if range_columns.is_empty() { - length - } else { - self.calculate_index_of_row::( - range_columns, - sort_options, - idx, - None, - last_range, - length, - )? - } - } + WindowFrameBound::CurrentRow => self.calculate_index_of_row::( + range_columns, + idx, + None, + length, + )?, WindowFrameBound::Following(ref n) => { if n.is_null() { // UNBOUNDED FOLLOWING @@ -247,15 +239,16 @@ impl WindowFrameStateRange { } else { self.calculate_index_of_row::( range_columns, - sort_options, idx, Some(n), - last_range, length, )? } } }; + // Store the resulting range so we can start from here subsequently: + self.last_range.start = start; + self.last_range.end = end; Ok(Range { start, end }) } @@ -265,17 +258,20 @@ impl WindowFrameStateRange { fn calculate_index_of_row( &mut self, range_columns: &[ArrayRef], - sort_options: &[SortOptions], idx: usize, delta: Option<&ScalarValue>, - last_range: &Range, length: usize, ) -> Result { let current_row_values = get_row_at_idx(range_columns, idx)?; let end_range = if let Some(delta) = delta { - let is_descending: bool = sort_options + let is_descending: bool = self + .sort_options .first() - .ok_or_else(|| DataFusionError::Internal("Array is empty".to_string()))? + .ok_or_else(|| { + DataFusionError::Internal( + "Sort options unexpectedly absent in a window frame".to_string(), + ) + })? .descending; current_row_values @@ -285,7 +281,7 @@ impl WindowFrameStateRange { return Ok(value.clone()); } if SEARCH_SIDE == is_descending { - // TODO: Handle positive overflows + // TODO: Handle positive overflows. value.add(delta) } else if value.is_unsigned() && value < delta { // NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue. @@ -293,7 +289,7 @@ impl WindowFrameStateRange { // change the following statement to use that. value.sub(value) } else { - // TODO: Handle negative overflows + // TODO: Handle negative overflows. value.sub(delta) } }) @@ -302,12 +298,12 @@ impl WindowFrameStateRange { current_row_values }; let search_start = if SIDE { - last_range.start + self.last_range.start } else { - last_range.end + self.last_range.end }; let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { - let cmp = compare_rows(current, target, sort_options)?; + let cmp = compare_rows(current, target, &self.sort_options)?; Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() }) }; search_in_slice(range_columns, &end_range, compare_fn, search_start, length) @@ -340,16 +336,15 @@ impl WindowFrameStateRange { // scan groups of data while processing window frames. #[derive(Debug, Default)] pub struct WindowFrameStateGroups { - current_group_idx: u64, + /// A tuple containing group values and the row index where the group ends. + /// Example: [[1, 1], [1, 1], [2, 1], [2, 1], ...] would correspond to + /// [([1, 1], 2), ([2, 1], 4), ...]. group_start_indices: VecDeque<(Vec, usize)>, - previous_row_values: Option>, - reached_end: bool, - window_frame_end_idx: u64, - window_frame_start_idx: u64, + /// The group index to which the row index belongs. + current_group_idx: usize, } impl WindowFrameStateGroups { - /// This function calculates beginning/ending indices for the frame of the current row. fn calculate_range( &mut self, window_frame: &Arc, @@ -357,662 +352,287 @@ impl WindowFrameStateGroups { length: usize, idx: usize, ) -> Result> { - if range_columns.is_empty() { - return Err(DataFusionError::Execution( - "GROUPS mode requires an ORDER BY clause".to_string(), - )); - } let start = match window_frame.start_bound { - // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0, - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => self - .calculate_index_of_group::(range_columns, idx, n, length)?, - WindowFrameBound::CurrentRow => self.calculate_index_of_group::( + WindowFrameBound::Preceding(ref n) => { + if n.is_null() { + // UNBOUNDED PRECEDING + 0 + } else { + self.calculate_index_of_row::( + range_columns, + idx, + Some(n), + length, + )? + } + } + WindowFrameBound::CurrentRow => self.calculate_index_of_row::( range_columns, idx, - 0, + None, length, )?, - WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => self - .calculate_index_of_group::(range_columns, idx, n, length)?, - // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::UInt64(None)) => { - return Err(DataFusionError::Internal(format!( - "Frame start cannot be UNBOUNDED FOLLOWING '{window_frame:?}'" - ))) - } - // ERRONEOUS FRAMES - WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { - return Err(DataFusionError::Internal( - "Groups should be Uint".to_string(), - )) - } - }; - let end = match window_frame.end_bound { - // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => { - return Err(DataFusionError::Internal(format!( - "Frame end cannot be UNBOUNDED PRECEDING '{window_frame:?}'" - ))) - } - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => self - .calculate_index_of_group::(range_columns, idx, n, length)?, - WindowFrameBound::CurrentRow => self - .calculate_index_of_group::( + WindowFrameBound::Following(ref n) => self + .calculate_index_of_row::( range_columns, idx, - 0, + Some(n), length, )?, - WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => self - .calculate_index_of_group::( + }; + let end = match window_frame.end_bound { + WindowFrameBound::Preceding(ref n) => self + .calculate_index_of_row::( range_columns, idx, - n, + Some(n), length, )?, - // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::UInt64(None)) => length, - // ERRONEOUS FRAMES - WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { - return Err(DataFusionError::Internal( - "Groups should be Uint".to_string(), - )) + WindowFrameBound::CurrentRow => self.calculate_index_of_row::( + range_columns, + idx, + None, + length, + )?, + WindowFrameBound::Following(ref n) => { + if n.is_null() { + // UNBOUNDED FOLLOWING + length + } else { + self.calculate_index_of_row::( + range_columns, + idx, + Some(n), + length, + )? + } } }; Ok(Range { start, end }) } - /// This function does the heavy lifting when finding group boundaries. It is meant to be - /// called twice, in succession, to get window frame start and end indices (with `BISECT_SIDE` - /// supplied as false and true, respectively). - fn calculate_index_of_group( + /// This function does the heavy lifting when finding range boundaries. It is meant to be + /// called twice, in succession, to get window frame start and end indices (with `SIDE` + /// supplied as true and false, respectively). Generic argument `SEARCH_SIDE` determines + /// the sign of `delta` (where true/false represents negative/positive respectively). + fn calculate_index_of_row( &mut self, range_columns: &[ArrayRef], idx: usize, - delta: u64, + delta: Option<&ScalarValue>, length: usize, ) -> Result { - let current_row_values = range_columns - .iter() - .map(|col| ScalarValue::try_from_array(col, idx)) - .collect::>>()?; - - if BISECT_SIDE { - // When we call this function to get the window frame start index, it tries to initialize - // the internal grouping state if this is not already done before. This initialization takes - // place only when the window frame start index is greater than or equal to zero. In this - // case, the current row is stored in group_start_indices, with row values as the group - // identifier and row index as the start index of the group. - if !self.initialized() { - self.initialize::(delta, range_columns)?; - } - } else if !self.reached_end { - // When we call this function to get the window frame end index, it extends the window - // frame one by one until the current row's window frame end index is reached by finding - // the next group. - self.extend_window_frame_if_necessary::(range_columns, delta)?; - } - // We keep track of previous row values, so that a group change can be identified. - // If there is a group change, the window frame is advanced and shifted by one group. - let group_change = match &self.previous_row_values { - None => false, - Some(values) => ¤t_row_values != values, - }; - if self.previous_row_values.is_none() || group_change { - self.previous_row_values = Some(current_row_values); - } - if group_change { - self.current_group_idx += 1; - self.advance_one_group::(range_columns)?; - self.shift_one_group::(delta); - } - Ok(if self.group_start_indices.is_empty() { - if self.reached_end { - length + let delta = if let Some(delta) = delta { + if let ScalarValue::UInt64(Some(value)) = delta { + *value as usize } else { - 0 - } - } else if BISECT_SIDE { - match self.group_start_indices.get(0) { - Some(&(_, idx)) => idx, - None => 0, + return Err(DataFusionError::Internal( + "Unexpectedly got a non-UInt64 value in a GROUPS mode window frame" + .to_string(), + )); } - } else { - match (self.reached_end, self.group_start_indices.back()) { - (false, Some(&(_, idx))) => idx, - _ => length, - } - }) - } - - fn extend_window_frame_if_necessary( - &mut self, - range_columns: &[ArrayRef], - delta: u64, - ) -> Result<()> { - let current_window_frame_end_idx = if !SEARCH_SIDE { - self.current_group_idx + delta + 1 - } else if self.current_group_idx >= delta { - self.current_group_idx - delta + 1 } else { 0 }; - if current_window_frame_end_idx == 0 { - // the end index of the window frame is still before the first index - return Ok(()); + let mut group_start = 0; + let last_group = self.group_start_indices.back(); + if let Some((_, group_end)) = last_group { + // Start searching from the last group boundary: + group_start = *group_end; } - if self.group_start_indices.is_empty() { - self.initialize_window_frame_start(range_columns)?; - } - while !self.reached_end - && self.window_frame_end_idx <= current_window_frame_end_idx - { - self.advance_one_group::(range_columns)?; + + // Advance groups until `idx` is inside a group: + while idx > group_start { + let group_row = get_row_at_idx(range_columns, group_start)?; + // Find end boundary of the group (search right boundary): + let group_end = search_in_slice( + range_columns, + &group_row, + check_equality, + group_start, + length, + )?; + self.group_start_indices.push_back((group_row, group_end)); + group_start = group_end; } - Ok(()) - } - fn initialize( - &mut self, - delta: u64, - range_columns: &[ArrayRef], - ) -> Result<()> { - if !SEARCH_SIDE { - self.window_frame_start_idx = self.current_group_idx + delta; - self.initialize_window_frame_start(range_columns) - } else if self.current_group_idx >= delta { - self.window_frame_start_idx = self.current_group_idx - delta; - self.initialize_window_frame_start(range_columns) - } else { - Ok(()) + // Update the group index `idx` belongs to: + while self.current_group_idx < self.group_start_indices.len() + && idx >= self.group_start_indices[self.current_group_idx].1 + { + self.current_group_idx += 1; } - } - fn initialize_window_frame_start( - &mut self, - range_columns: &[ArrayRef], - ) -> Result<()> { - let mut group_values = range_columns - .iter() - .map(|col| ScalarValue::try_from_array(col, 0)) - .collect::>>()?; - let mut start_idx: usize = 0; - for _ in 0..self.window_frame_start_idx { - let next_group_and_start_index = - WindowFrameStateGroups::find_next_group_and_start_index( - range_columns, - &group_values, - start_idx, - )?; - if let Some(entry) = next_group_and_start_index { - (group_values, start_idx) = entry; + // Find the group index of the frame boundary: + let group_idx = if SEARCH_SIDE { + if self.current_group_idx > delta { + self.current_group_idx - delta } else { - // not enough groups to generate a window frame - self.window_frame_end_idx = self.window_frame_start_idx; - self.reached_end = true; - return Ok(()); + 0 } - } - self.group_start_indices - .push_back((group_values, start_idx)); - self.window_frame_end_idx = self.window_frame_start_idx + 1; - Ok(()) - } - - fn initialized(&self) -> bool { - self.reached_end || !self.group_start_indices.is_empty() - } - - /// This function advances the window frame by one group. - fn advance_one_group( - &mut self, - range_columns: &[ArrayRef], - ) -> Result<()> { - let last_group_values = self.group_start_indices.back(); - let last_group_values = if let Some(values) = last_group_values { - values } else { - return Ok(()); + self.current_group_idx + delta }; - let next_group_and_start_index = - WindowFrameStateGroups::find_next_group_and_start_index( + + // Extend `group_start_indices` until it includes at least `group_idx`: + while self.group_start_indices.len() <= group_idx && group_start < length { + let group_row = get_row_at_idx(range_columns, group_start)?; + // Find end boundary of the group (search right boundary): + let group_end = search_in_slice( range_columns, - &last_group_values.0, - last_group_values.1, + &group_row, + check_equality, + group_start, + length, )?; - if let Some(entry) = next_group_and_start_index { - self.group_start_indices.push_back(entry); - self.window_frame_end_idx += 1; - } else { - // not enough groups to proceed - self.reached_end = true; + self.group_start_indices.push_back((group_row, group_end)); + group_start = group_end; } - Ok(()) - } - /// This function drops the oldest group from the window frame. - fn shift_one_group(&mut self, delta: u64) { - let current_window_frame_start_idx = if !SEARCH_SIDE { - self.current_group_idx + delta - } else if self.current_group_idx >= delta { - self.current_group_idx - delta - } else { - 0 - }; - if current_window_frame_start_idx > self.window_frame_start_idx { - self.group_start_indices.pop_front(); - self.window_frame_start_idx += 1; - } - } - - /// This function finds the next group and its start index for a given group and start index. - /// It utilizes an exponentially growing step size to find the group boundary. - // TODO: For small group sizes, proceeding one-by-one to find the group change can be more efficient. - // Statistics about previous group sizes can be used to choose one-by-one vs. exponentially growing, - // or even to set the base step_size when exponentially growing. We can also create a benchmark - // implementation to get insights about the crossover point. - fn find_next_group_and_start_index( - range_columns: &[ArrayRef], - current_row_values: &[ScalarValue], - idx: usize, - ) -> Result, usize)>> { - let mut step_size: usize = 1; - let data_size: usize = range_columns - .get(0) - .ok_or_else(|| { - DataFusionError::Internal("Column array shouldn't be empty".to_string()) - })? - .len(); - let mut low = idx; - let mut high = idx + step_size; - while high < data_size { - let val = range_columns - .iter() - .map(|arr| ScalarValue::try_from_array(arr, high)) - .collect::>>()?; - if val == current_row_values { - low = high; - step_size *= 2; - high += step_size; - } else { - break; + // Calculate index of the group boundary: + Ok(match (SIDE, SEARCH_SIDE) { + // Window frame start: + (true, _) => { + let group_idx = min(group_idx, self.group_start_indices.len()); + if group_idx > 0 { + // Normally, start at the boundary of the previous group. + self.group_start_indices[group_idx - 1].1 + } else { + // If previous group is out of the table, start at zero. + 0 + } } - } - low = find_bisect_point( - range_columns, - current_row_values, - |current, to_compare| Ok(current == to_compare), - low, - min(high, data_size), - )?; - if low == data_size { - return Ok(None); - } - let val = range_columns - .iter() - .map(|arr| ScalarValue::try_from_array(arr, low)) - .collect::>>()?; - Ok(Some((val, low))) + // Window frame end, PRECEDING n + (false, true) => { + if self.current_group_idx >= delta { + let group_idx = self.current_group_idx - delta; + self.group_start_indices[group_idx].1 + } else { + // Group is out of the table, therefore end at zero. + 0 + } + } + // Window frame end, FOLLOWING n + (false, false) => { + let group_idx = min( + self.current_group_idx + delta, + self.group_start_indices.len() - 1, + ); + self.group_start_indices[group_idx].1 + } + }) } } +fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result { + Ok(current == target) +} + #[cfg(test)] mod tests { - use arrow::array::Float64Array; - use datafusion_common::ScalarValue; + use crate::window::window_frame_state::WindowFrameStateGroups; + use arrow::array::{ArrayRef, Float64Array}; + use arrow_schema::SortOptions; + use datafusion_common::from_slice::FromSlice; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; + use std::ops::Range; use std::sync::Arc; - use crate::from_slice::FromSlice; - - use super::*; - - struct TestData { - arrays: Vec, - group_indices: [usize; 6], - num_groups: usize, - num_rows: usize, - next_group_indices: [usize; 5], - } - - fn test_data() -> TestData { - let num_groups: usize = 5; - let num_rows: usize = 6; - let group_indices = [0, 1, 2, 2, 4, 5]; - let next_group_indices = [1, 2, 4, 4, 5]; - - let arrays: Vec = vec![ - Arc::new(Float64Array::from_slice([5.0, 7.0, 8.0, 8., 9., 10.])), - Arc::new(Float64Array::from_slice([2.0, 3.0, 3.0, 3., 4.0, 5.0])), - Arc::new(Float64Array::from_slice([5.0, 7.0, 8.0, 8., 10., 11.0])), - Arc::new(Float64Array::from_slice([15.0, 13.0, 8.0, 8., 5., 0.0])), - ]; - TestData { - arrays, - group_indices, - num_groups, - num_rows, - next_group_indices, - } - } - - #[test] - fn test_find_next_group_and_start_index() { - let test_data = test_data(); - for (current_idx, next_idx) in test_data.next_group_indices.iter().enumerate() { - let current_row_values = test_data - .arrays - .iter() - .map(|col| ScalarValue::try_from_array(col, current_idx)) - .collect::>>() - .unwrap(); - let next_row_values = test_data - .arrays - .iter() - .map(|col| ScalarValue::try_from_array(col, *next_idx)) - .collect::>>() - .unwrap(); - let res = WindowFrameStateGroups::find_next_group_and_start_index( - &test_data.arrays, - ¤t_row_values, - current_idx, - ) - .unwrap(); - assert_eq!(res, Some((next_row_values, *next_idx))); - } - let current_idx = test_data.num_rows - 1; - let current_row_values = test_data - .arrays - .iter() - .map(|col| ScalarValue::try_from_array(col, current_idx)) - .collect::>>() - .unwrap(); - let res = WindowFrameStateGroups::find_next_group_and_start_index( - &test_data.arrays, - ¤t_row_values, - current_idx, - ) - .unwrap(); - assert_eq!(res, None); - } - - #[test] - fn test_window_frame_groups_preceding_delta_greater_than_partition_size() { - const START: bool = true; - const END: bool = false; - const PRECEDING: bool = true; - const DELTA: u64 = 10; - - let test_data = test_data(); - let mut window_frame_groups = WindowFrameStateGroups::default(); - window_frame_groups - .initialize::(DELTA, &test_data.arrays) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 0); - assert!(!window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); - - window_frame_groups - .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 0); - assert!(!window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); - - for idx in 0..test_data.num_rows { - let start = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(start, 0); - let end = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(end, 0); - } - } - - #[test] - fn test_window_frame_groups_following_delta_greater_than_partition_size() { - const START: bool = true; - const END: bool = false; - const FOLLOWING: bool = false; - const DELTA: u64 = 10; - - let test_data = test_data(); - let mut window_frame_groups = WindowFrameStateGroups::default(); - window_frame_groups - .initialize::(DELTA, &test_data.arrays) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, DELTA); - assert_eq!(window_frame_groups.window_frame_end_idx, DELTA); - assert!(window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); - - window_frame_groups - .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, DELTA); - assert_eq!(window_frame_groups.window_frame_end_idx, DELTA); - assert!(window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); + fn get_test_data() -> (Vec, Vec) { + let range_columns: Vec = vec![Arc::new(Float64Array::from_slice([ + 5.0, 7.0, 8.0, 8.0, 9., 10., 10., 10., 11., + ]))]; + let sort_options = vec![SortOptions { + descending: false, + nulls_first: false, + }]; - for idx in 0..test_data.num_rows { - let start = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(start, test_data.num_rows); - let end = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(end, test_data.num_rows); - } + (range_columns, sort_options) } - #[test] - fn test_window_frame_groups_preceding_and_following_delta_greater_than_partition_size( - ) { - const START: bool = true; - const END: bool = false; - const FOLLOWING: bool = false; - const PRECEDING: bool = true; - const DELTA: u64 = 10; - - let test_data = test_data(); + fn assert_expected( + expected_results: Vec<(Range, usize)>, + window_frame: &Arc, + ) -> Result<()> { let mut window_frame_groups = WindowFrameStateGroups::default(); - window_frame_groups - .initialize::(DELTA, &test_data.arrays) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 0); - assert!(!window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); - - window_frame_groups - .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!( - window_frame_groups.window_frame_end_idx, - test_data.num_groups as u64 - ); - assert!(window_frame_groups.reached_end); - assert_eq!( - window_frame_groups.group_start_indices.len(), - test_data.num_groups - ); - - for idx in 0..test_data.num_rows { - let start = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(start, 0); - let end = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(end, test_data.num_rows); + let (range_columns, _) = get_test_data(); + let n_row = range_columns[0].len(); + for (idx, (expected_range, expected_group_idx)) in + expected_results.into_iter().enumerate() + { + let range = window_frame_groups.calculate_range( + window_frame, + &range_columns, + n_row, + idx, + )?; + assert_eq!(range, expected_range); + assert_eq!(window_frame_groups.current_group_idx, expected_group_idx); } + Ok(()) } #[test] - fn test_window_frame_groups_preceding_and_following_1() { - const START: bool = true; - const END: bool = false; - const FOLLOWING: bool = false; - const PRECEDING: bool = true; - const DELTA: u64 = 1; - - let test_data = test_data(); - let mut window_frame_groups = WindowFrameStateGroups::default(); - window_frame_groups - .initialize::(DELTA, &test_data.arrays) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 0); - assert!(!window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); - - window_frame_groups - .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 2 * DELTA + 1); - assert!(!window_frame_groups.reached_end); - assert_eq!( - window_frame_groups.group_start_indices.len(), - 2 * DELTA as usize + 1 - ); - - for idx in 0..test_data.num_rows { - let start_idx = if idx < DELTA as usize { - 0 - } else { - test_data.group_indices[idx] - DELTA as usize - }; - let start = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(start, test_data.group_indices[start_idx]); - let mut end_idx = if idx >= test_data.num_groups { - test_data.num_rows - } else { - test_data.next_group_indices[idx] - }; - for _ in 0..DELTA { - end_idx = if end_idx >= test_data.num_groups { - test_data.num_rows - } else { - test_data.next_group_indices[end_idx] - }; - } - let end = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(end, end_idx); - } + fn test_window_frame_group_boundaries() -> Result<()> { + let window_frame = Arc::new(WindowFrame { + units: WindowFrameUnits::Groups, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))), + end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(1))), + }); + let expected_results = vec![ + (Range { start: 0, end: 2 }, 0), + (Range { start: 0, end: 4 }, 1), + (Range { start: 1, end: 5 }, 2), + (Range { start: 1, end: 5 }, 2), + (Range { start: 2, end: 8 }, 3), + (Range { start: 4, end: 9 }, 4), + (Range { start: 4, end: 9 }, 4), + (Range { start: 4, end: 9 }, 4), + (Range { start: 5, end: 9 }, 5), + ]; + assert_expected(expected_results, &window_frame) } #[test] - fn test_window_frame_groups_preceding_1_and_unbounded_following() { - const START: bool = true; - const PRECEDING: bool = true; - const DELTA: u64 = 1; - - let test_data = test_data(); - let mut window_frame_groups = WindowFrameStateGroups::default(); - window_frame_groups - .initialize::(DELTA, &test_data.arrays) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 0); - assert!(!window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); - - for idx in 0..test_data.num_rows { - let start_idx = if idx < DELTA as usize { - 0 - } else { - test_data.group_indices[idx] - DELTA as usize - }; - let start = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(start, test_data.group_indices[start_idx]); - } + fn test_window_frame_group_boundaries_both_following() -> Result<()> { + let window_frame = Arc::new(WindowFrame { + units: WindowFrameUnits::Groups, + start_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(1))), + end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), + }); + let expected_results = vec![ + (Range:: { start: 1, end: 4 }, 0), + (Range:: { start: 2, end: 5 }, 1), + (Range:: { start: 4, end: 8 }, 2), + (Range:: { start: 4, end: 8 }, 2), + (Range:: { start: 5, end: 9 }, 3), + (Range:: { start: 8, end: 9 }, 4), + (Range:: { start: 8, end: 9 }, 4), + (Range:: { start: 8, end: 9 }, 4), + (Range:: { start: 9, end: 9 }, 5), + ]; + assert_expected(expected_results, &window_frame) } #[test] - fn test_window_frame_groups_current_and_unbounded_following() { - const START: bool = true; - const PRECEDING: bool = true; - const DELTA: u64 = 0; - - let test_data = test_data(); - let mut window_frame_groups = WindowFrameStateGroups::default(); - window_frame_groups - .initialize::(DELTA, &test_data.arrays) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 1); - assert!(!window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 1); - - for idx in 0..test_data.num_rows { - let start = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(start, test_data.group_indices[idx]); - } + fn test_window_frame_group_boundaries_both_preceding() -> Result<()> { + let window_frame = Arc::new(WindowFrame { + units: WindowFrameUnits::Groups, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), + end_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))), + }); + let expected_results = vec![ + (Range:: { start: 0, end: 0 }, 0), + (Range:: { start: 0, end: 1 }, 1), + (Range:: { start: 0, end: 2 }, 2), + (Range:: { start: 0, end: 2 }, 2), + (Range:: { start: 1, end: 4 }, 3), + (Range:: { start: 2, end: 5 }, 4), + (Range:: { start: 2, end: 5 }, 4), + (Range:: { start: 2, end: 5 }, 4), + (Range:: { start: 4, end: 8 }, 5), + ]; + assert_expected(expected_results, &window_frame) } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index a74874586ce5..498563b2ab47 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -43,9 +43,10 @@ use datafusion_expr::{ regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, substring, tan, to_hex, to_timestamp_micros, to_timestamp_millis, - to_timestamp_seconds, translate, trim, trunc, upper, uuid, AggregateFunction, - Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, - GetIndexedField, GroupingSet, + to_timestamp_seconds, translate, trim, trunc, upper, uuid, + window_frame::regularize, + AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, + Case, Cast, Expr, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -907,16 +908,15 @@ pub fn parse_expr( .window_frame .as_ref() .map::, _>(|window_frame| { - let window_frame: WindowFrame = window_frame.clone().try_into()?; - if WindowFrameUnits::Range == window_frame.units - && order_by.len() != 1 - { - Err(proto_error("With window frame of type RANGE, the order by expression must be of length 1")) - } else { - Ok(window_frame) - } + let window_frame = window_frame.clone().try_into()?; + regularize(window_frame, order_by.len()) }) - .transpose()?.ok_or_else(||{DataFusionError::Execution("expects somothing".to_string())})?; + .transpose()? + .ok_or_else(|| { + DataFusionError::Execution( + "missing window frame during deserialization".to_string(), + ) + })?; match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 1845d59472c6..c5f23213aa31 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -19,9 +19,10 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::utils::normalize_ident; use datafusion_common::{DFSchema, DataFusionError, Result}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_expr::window_frame::regularize; use datafusion_expr::{ expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, - WindowFrameUnits, WindowFunction, + WindowFunction, }; use sqlparser::ast::{ Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, @@ -65,15 +66,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .window_frame .as_ref() .map(|window_frame| { - let window_frame: WindowFrame = window_frame.clone().try_into()?; - if WindowFrameUnits::Range == window_frame.units - && order_by.len() != 1 - { - Err(DataFusionError::Plan(format!( - "With window frame of type RANGE, the order by expression must be of length 1, got {}", order_by.len()))) - } else { - Ok(window_frame) - } + let window_frame = window_frame.clone().try_into()?; + regularize(window_frame, order_by.len()) }) .transpose()?; let window_frame = if let Some(window_frame) = window_frame { diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs index c75f93d36932..44c0559ef35a 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/integration_test.rs @@ -2101,27 +2101,6 @@ fn over_order_by_with_window_frame_single_end() { quick_test(sql, expected); } -#[test] -fn over_order_by_with_window_frame_range_order_by_check() { - let sql = "SELECT order_id, MAX(qty) OVER (RANGE UNBOUNDED PRECEDING) from orders"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "Plan(\"With window frame of type RANGE, the order by expression must be of length 1, got 0\")", - format!("{err:?}") - ); -} - -#[test] -fn over_order_by_with_window_frame_range_order_by_check_2() { - let sql = - "SELECT order_id, MAX(qty) OVER (ORDER BY order_id, qty RANGE UNBOUNDED PRECEDING) from orders"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "Plan(\"With window frame of type RANGE, the order by expression must be of length 1, got 2\")", - format!("{err:?}") - ); -} - #[test] fn over_order_by_with_window_frame_single_end_groups() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders";