From edae175e7a6685a15aa4429a0e0ba6224bb64ea4 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 17 Jan 2023 17:28:09 +0300 Subject: [PATCH 01/18] add naive linear search --- datafusion/common/src/bisect.rs | 49 +++++++++++++++++++ .../src/window/window_frame_state.rs | 10 +++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/datafusion/common/src/bisect.rs b/datafusion/common/src/bisect.rs index 796598be2cd0..1b55889b30a6 100644 --- a/datafusion/common/src/bisect.rs +++ b/datafusion/common/src/bisect.rs @@ -74,6 +74,55 @@ pub fn bisect( find_bisect_point(item_columns, target, compare_fn, low, high) } +/// This function implements both bisect_left and bisect_right, having the same +/// semantics with the Python Standard Library. To use bisect_left, supply true +/// as the template argument. To use bisect_right, supply false as the template argument. +pub fn linear_search( + item_columns: &[ArrayRef], + target: &[ScalarValue], + sort_options: &[SortOptions], +) -> Result { + let low: usize = 0; + let high: usize = item_columns + .get(0) + .ok_or_else(|| { + DataFusionError::Internal("Column array shouldn't be empty".to_string()) + })? + .len(); + let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { + let cmp = compare(current, target, sort_options)?; + Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() }) + }; + find_linear_point(item_columns, target, compare_fn, low, high) +} + +/// This function searches for a tuple of target values among the given rows using the bisection algorithm. +/// The boolean-valued function `compare_fn` specifies whether we bisect on the left (with return value `false`), +/// or on the right (with return value `true`) as we compare the target value with the current value as we iteratively +/// bisect the input. +pub fn find_linear_point( + item_columns: &[ArrayRef], + target: &[ScalarValue], + compare_fn: F, + mut low: usize, + mut high: usize, +) -> Result +where + F: Fn(&[ScalarValue], &[ScalarValue]) -> Result, +{ + while low < high { + let val = item_columns + .iter() + .map(|arr| ScalarValue::try_from_array(arr, low)) + .collect::>>()?; + if !compare_fn(&val, target)? { + break; + } + low += 1; + } + Ok(low) +} + /// This function searches for a tuple of target values among the given rows using the bisection algorithm. /// The boolean-valued function `compare_fn` specifies whether we bisect on the left (with return value `false`), /// or on the right (with return value `true`) as we compare the target value with the current value as we iteratively diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 9c559cabd170..41b31f020f13 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -20,7 +20,7 @@ use arrow::array::ArrayRef; use arrow::compute::kernels::sort::SortOptions; -use datafusion_common::bisect::{bisect, find_bisect_point}; +use datafusion_common::bisect::{bisect, find_bisect_point, linear_search}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::cmp::min; @@ -286,7 +286,13 @@ impl WindowFrameStateRange { current_row_values }; // `BISECT_SIDE` true means bisect_left, false means bisect_right - bisect::(range_columns, &end_range, sort_options) + let linear = + linear_search::(range_columns, &end_range, sort_options)?; + // `BISECT_SIDE` true means bisect_left, false means bisect_right + // let res = bisect::(range_columns, &end_range, sort_options)?; + // println!("linear: {:?}", linear); + // println!("bisect: {:?}", res); + Ok(linear) } } From e738e9ffa7f3b01b083d33aaae4a2d9a173381f8 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 17 Jan 2023 17:45:37 +0300 Subject: [PATCH 02/18] Add last range to decrease search size --- datafusion/common/src/bisect.rs | 8 +++++- .../src/window/window_frame_state.rs | 26 +++++++++++-------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/datafusion/common/src/bisect.rs b/datafusion/common/src/bisect.rs index 1b55889b30a6..814ae561ecc8 100644 --- a/datafusion/common/src/bisect.rs +++ b/datafusion/common/src/bisect.rs @@ -21,6 +21,7 @@ use crate::{DataFusionError, Result, ScalarValue}; use arrow::array::ArrayRef; use arrow::compute::SortOptions; use std::cmp::Ordering; +use std::ops::Range; /// This function compares two tuples depending on the given sort options. fn compare( @@ -81,8 +82,13 @@ pub fn linear_search( item_columns: &[ArrayRef], target: &[ScalarValue], sort_options: &[SortOptions], + last_range: &Range, ) -> Result { - let low: usize = 0; + let low: usize = if SIDE { + last_range.start + } else { + last_range.end + }; let high: usize = item_columns .get(0) .ok_or_else(|| { diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 41b31f020f13..fbe5759cdc88 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -159,7 +159,9 @@ impl<'a> WindowFrameContext<'a> { /// scan ranges of data while processing window frames. Currently we calculate /// things from scratch every time, but we will make this incremental in the future. #[derive(Debug, Default)] -pub struct WindowFrameStateRange {} +pub struct WindowFrameStateRange { + last_range: Range, +} impl WindowFrameStateRange { /// This function calculates beginning/ending indices for the frame of the current row. @@ -239,13 +241,15 @@ impl WindowFrameStateRange { } } }; - Ok(Range { start, end }) + self.last_range = Range { start, end }; + println!("self.last_range: {:?}", self.last_range); + Ok(self.last_range.clone()) } /// This function does the heavy lifting when finding range boundaries. It is meant to be - /// called twice, in succession, to get window frame start and end indices (with `BISECT_SIDE` - /// supplied as false and true, respectively). - fn calculate_index_of_row( + /// called twice, in succession, to get window frame start and end indices (with `SIDE` + /// supplied as true and false, respectively). + fn calculate_index_of_row( &mut self, range_columns: &[ArrayRef], sort_options: &[SortOptions], @@ -286,12 +290,12 @@ impl WindowFrameStateRange { current_row_values }; // `BISECT_SIDE` true means bisect_left, false means bisect_right - let linear = - linear_search::(range_columns, &end_range, sort_options)?; - // `BISECT_SIDE` true means bisect_left, false means bisect_right - // let res = bisect::(range_columns, &end_range, sort_options)?; - // println!("linear: {:?}", linear); - // println!("bisect: {:?}", res); + let linear = linear_search::( + range_columns, + &end_range, + sort_options, + &self.last_range, + )?; Ok(linear) } } From 970edf3839eadf3890dd8a21f1b93c52077ccb76 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Wed, 18 Jan 2023 09:36:56 +0300 Subject: [PATCH 03/18] minor changes --- datafusion/common/src/lib.rs | 2 +- datafusion/common/src/{bisect.rs => utils.rs} | 109 ++++++++---------- .../physical-expr/src/window/aggregate.rs | 9 +- .../physical-expr/src/window/built_in.rs | 8 +- .../physical-expr/src/window/nth_value.rs | 6 +- .../src/window/partition_evaluator.rs | 2 +- .../src/window/sliding_aggregate.rs | 1 + .../src/window/window_frame_state.rs | 37 +++--- 8 files changed, 92 insertions(+), 82 deletions(-) rename datafusion/common/src/{bisect.rs => utils.rs} (86%) diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 2935cd79639d..5be3f9411c69 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -pub mod bisect; pub mod cast; mod column; pub mod config; @@ -30,6 +29,7 @@ pub mod scalar; pub mod stats; mod table_reference; pub mod test_util; +pub mod utils; use arrow::compute::SortOptions; pub use column::Column; diff --git a/datafusion/common/src/bisect.rs b/datafusion/common/src/utils.rs similarity index 86% rename from datafusion/common/src/bisect.rs rename to datafusion/common/src/utils.rs index 814ae561ecc8..eb908548d9cb 100644 --- a/datafusion/common/src/bisect.rs +++ b/datafusion/common/src/utils.rs @@ -21,7 +21,14 @@ use crate::{DataFusionError, Result, ScalarValue}; use arrow::array::ArrayRef; use arrow::compute::SortOptions; use std::cmp::Ordering; -use std::ops::Range; + +/// Given column vectors, returns row at `idx` +fn get_row_at_idx(item_columns: &[ArrayRef], idx: usize) -> Result> { + item_columns + .iter() + .map(|arr| ScalarValue::try_from_array(arr, idx)) + .collect::>>() +} /// This function compares two tuples depending on the given sort options. fn compare( @@ -75,38 +82,11 @@ pub fn bisect( find_bisect_point(item_columns, target, compare_fn, low, high) } -/// This function implements both bisect_left and bisect_right, having the same -/// semantics with the Python Standard Library. To use bisect_left, supply true -/// as the template argument. To use bisect_right, supply false as the template argument. -pub fn linear_search( - item_columns: &[ArrayRef], - target: &[ScalarValue], - sort_options: &[SortOptions], - last_range: &Range, -) -> Result { - let low: usize = if SIDE { - last_range.start - } else { - last_range.end - }; - let high: usize = item_columns - .get(0) - .ok_or_else(|| { - DataFusionError::Internal("Column array shouldn't be empty".to_string()) - })? - .len(); - let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { - let cmp = compare(current, target, sort_options)?; - Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() }) - }; - find_linear_point(item_columns, target, compare_fn, low, high) -} - /// This function searches for a tuple of target values among the given rows using the bisection algorithm. /// The boolean-valued function `compare_fn` specifies whether we bisect on the left (with return value `false`), /// or on the right (with return value `true`) as we compare the target value with the current value as we iteratively /// bisect the input. -pub fn find_linear_point( +pub fn find_bisect_point( item_columns: &[ArrayRef], target: &[ScalarValue], compare_fn: F, @@ -117,43 +97,44 @@ where F: Fn(&[ScalarValue], &[ScalarValue]) -> Result, { while low < high { - let val = item_columns - .iter() - .map(|arr| ScalarValue::try_from_array(arr, low)) - .collect::>>()?; - if !compare_fn(&val, target)? { - break; + let mid = ((high - low) / 2) + low; + let val = get_row_at_idx(item_columns, mid)?; + if compare_fn(&val, target)? { + low = mid + 1; + } else { + high = mid; } - low += 1; } Ok(low) } -/// This function searches for a tuple of target values among the given rows using the bisection algorithm. -/// The boolean-valued function `compare_fn` specifies whether we bisect on the left (with return value `false`), -/// or on the right (with return value `true`) as we compare the target value with the current value as we iteratively -/// bisect the input. -pub fn find_bisect_point( +/// This function implements linear_search, It starts searching from row at `start` index +/// of `item_columns` until last row of the `item_columns`. It assumes `item_columns` is sorted +/// according to `sort_options` and returns would insertion position of the `target`. +/// `SIDE` is `true` means left insertion is applied. +/// `SIDE` is `false` means right insertion is applied. +pub fn linear_search( item_columns: &[ArrayRef], target: &[ScalarValue], - compare_fn: F, + sort_options: &[SortOptions], mut low: usize, - mut high: usize, -) -> Result -where - F: Fn(&[ScalarValue], &[ScalarValue]) -> Result, -{ +) -> Result { + let high: usize = item_columns + .get(0) + .ok_or_else(|| { + DataFusionError::Internal("Column array shouldn't be empty".to_string()) + })? + .len(); + let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { + let cmp = compare(current, target, sort_options)?; + Ok::(if SIDE { cmp.is_lt() } else { cmp.is_le() }) + }; while low < high { - let mid = ((high - low) / 2) + low; - let val = item_columns - .iter() - .map(|arr| ScalarValue::try_from_array(arr, mid)) - .collect::>>()?; - if compare_fn(&val, target)? { - low = mid + 1; - } else { - high = mid; + let val = get_row_at_idx(item_columns, low)?; + if !compare_fn(&val, target)? { + break; } + low += 1; } Ok(low) } @@ -170,7 +151,7 @@ mod tests { use super::*; #[test] - fn test_bisect_left_and_right() { + fn test_bisect_linear_left_and_right() { let arrays: Vec = vec![ Arc::new(Float64Array::from_slice([5.0, 7.0, 8.0, 9., 10.])), Arc::new(Float64Array::from_slice([2.0, 3.0, 3.0, 4.0, 5.0])), @@ -205,6 +186,11 @@ mod tests { assert_eq!(res, 2); let res: usize = bisect::(&arrays, &search_tuple, &ords).unwrap(); assert_eq!(res, 3); + let res: usize = linear_search::(&arrays, &search_tuple, &ords, 0).unwrap(); + assert_eq!(res, 2); + let res: usize = + linear_search::(&arrays, &search_tuple, &ords, 0).unwrap(); + assert_eq!(res, 3); } #[test] @@ -241,7 +227,7 @@ mod tests { } #[test] - fn test_bisect_left_and_right_diff_sort() { + fn test_bisect_linear_left_and_right_diff_sort() { // Descending, left let arrays: Vec = vec![Arc::new(Float64Array::from_slice([ 4.0, 3.0, 2.0, 1.0, 0.0, @@ -311,5 +297,12 @@ mod tests { let res: usize = bisect::(&arrays, &search_tuple, &ords).unwrap(); assert_eq!(res, 2); + + let res: usize = + linear_search::(&arrays, &search_tuple, &ords, 0).unwrap(); + assert_eq!(res, 3); + + let res: usize = linear_search::(&arrays, &search_tuple, &ords, 0).unwrap(); + assert_eq!(res, 2); } } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index df61e7cc8fbb..fe725f2d71a5 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -106,8 +106,13 @@ impl WindowExpr for AggregateWindowExpr { // We iterate on each row to perform a running calculation. // First, cur_range is calculated, then it is compared with last_range. for i in 0..length { - let cur_range = - window_frame_ctx.calculate_range(&order_bys, &sort_options, length, i)?; + let cur_range = window_frame_ctx.calculate_range( + &order_bys, + &sort_options, + length, + i, + &last_range, + )?; let value = if cur_range.end == cur_range.start { // We produce None if the window is empty. ScalarValue::try_from(self.aggregate.field()?.data_type())? diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index f0484b790fbc..bfb8e835b84d 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -34,6 +34,7 @@ use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{WindowFrame, WindowFrameUnits}; use std::any::Any; +use std::ops::Range; use std::sync::Arc; /// A window expr that takes the form of a built in window function @@ -104,15 +105,17 @@ impl WindowExpr for BuiltInWindowExpr { let length = batch.num_rows(); let (values, order_bys) = self.get_values_orderbys(batch)?; let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + let mut range = Range { start: 0, end: 0 }; // We iterate on each row to calculate window frame range and and window function result for idx in 0..length { - let range = window_frame_ctx.calculate_range( + range = window_frame_ctx.calculate_range( &order_bys, &sort_options, num_rows, idx, + &range, )?; - let value = evaluator.evaluate_inside_range(&values, range)?; + let value = evaluator.evaluate_inside_range(&values, &range)?; row_wise_results.push(value); } ScalarValue::iter_to_array(row_wise_results.into_iter()) @@ -185,6 +188,7 @@ impl WindowExpr for BuiltInWindowExpr { &sort_options, num_rows, idx, + &state.window_frame_range, ) } else { evaluator.get_range(state, num_rows) diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index c3c3b55d4e88..c40a4fa7de03 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -176,13 +176,13 @@ impl PartitionEvaluator for NthValueEvaluator { } fn evaluate_stateful(&mut self, values: &[ArrayRef]) -> Result { - self.evaluate_inside_range(values, self.state.range.clone()) + self.evaluate_inside_range(values, &self.state.range) } fn evaluate_inside_range( &self, values: &[ArrayRef], - range: Range, + range: &Range, ) -> Result { // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take single column, values will have size 1 let arr = &values[0]; @@ -227,7 +227,7 @@ mod tests { let evaluator = expr.create_evaluator()?; let values = expr.evaluate_args(&batch)?; let result = ranges - .into_iter() + .iter() .map(|range| evaluator.evaluate_inside_range(&values, range)) .into_iter() .collect::>>()?; diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index e6cead76d13d..44fbb2d94567 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -83,7 +83,7 @@ pub trait PartitionEvaluator: Debug + Send { fn evaluate_inside_range( &self, _values: &[ArrayRef], - _range: Range, + _range: &Range, ) -> Result { Err(DataFusionError::NotImplemented( "evaluate_inside_range is not implemented by default".into(), diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 587c313e31bd..a429f658cb5d 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -268,6 +268,7 @@ impl SlidingAggregateWindowExpr { &sort_options, length, *idx, + last_range, )?; // Exit if range end index is length, need kind of flag to stop if cur_range.end == length && !is_end { diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index fbe5759cdc88..85c8b07e3165 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -20,7 +20,7 @@ use arrow::array::ArrayRef; use arrow::compute::kernels::sort::SortOptions; -use datafusion_common::bisect::{bisect, find_bisect_point, linear_search}; +use datafusion_common::utils::{find_bisect_point, linear_search}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::cmp::min; @@ -69,6 +69,7 @@ impl<'a> WindowFrameContext<'a> { sort_options: &[SortOptions], length: usize, idx: usize, + last_range: &Range, ) -> Result> { match *self { WindowFrameContext::Rows(window_frame) => { @@ -85,6 +86,7 @@ impl<'a> WindowFrameContext<'a> { sort_options, length, idx, + last_range, ), // sort_options is not used in GROUPS mode calculations as the inequality of two rows is the indicator // of a group change, and the ordering and the position of the nulls do not have impact on inequality. @@ -159,9 +161,7 @@ impl<'a> WindowFrameContext<'a> { /// scan ranges of data while processing window frames. Currently we calculate /// things from scratch every time, but we will make this incremental in the future. #[derive(Debug, Default)] -pub struct WindowFrameStateRange { - last_range: Range, -} +pub struct WindowFrameStateRange {} impl WindowFrameStateRange { /// This function calculates beginning/ending indices for the frame of the current row. @@ -172,6 +172,7 @@ impl WindowFrameStateRange { sort_options: &[SortOptions], length: usize, idx: usize, + last_range: &Range, ) -> Result> { let start = match window_frame.start_bound { WindowFrameBound::Preceding(ref n) => { @@ -184,6 +185,7 @@ impl WindowFrameStateRange { sort_options, idx, Some(n), + last_range, )? } } @@ -196,6 +198,7 @@ impl WindowFrameStateRange { sort_options, idx, None, + last_range, )? } } @@ -205,6 +208,7 @@ impl WindowFrameStateRange { sort_options, idx, Some(n), + last_range, )?, }; let end = match window_frame.end_bound { @@ -214,6 +218,7 @@ impl WindowFrameStateRange { sort_options, idx, Some(n), + last_range, )?, WindowFrameBound::CurrentRow => { if range_columns.is_empty() { @@ -224,6 +229,7 @@ impl WindowFrameStateRange { sort_options, idx, None, + last_range, )? } } @@ -237,13 +243,12 @@ impl WindowFrameStateRange { sort_options, idx, Some(n), + last_range, )? } } }; - self.last_range = Range { start, end }; - println!("self.last_range: {:?}", self.last_range); - Ok(self.last_range.clone()) + Ok(Range { start, end }) } /// This function does the heavy lifting when finding range boundaries. It is meant to be @@ -255,6 +260,7 @@ impl WindowFrameStateRange { sort_options: &[SortOptions], idx: usize, delta: Option<&ScalarValue>, + last_range: &Range, ) -> Result { let current_row_values = range_columns .iter() @@ -289,14 +295,15 @@ impl WindowFrameStateRange { } else { current_row_values }; - // `BISECT_SIDE` true means bisect_left, false means bisect_right - let linear = linear_search::( - range_columns, - &end_range, - sort_options, - &self.last_range, - )?; - Ok(linear) + let search_start = if SIDE { + last_range.start + } else { + last_range.end + }; + // `SIDE` true means from left insert, false means right insert + let res = + linear_search::(range_columns, &end_range, sort_options, search_start)?; + Ok(res) } } From a92a7e22d8c956cae8b0fd29ebea3433cda0fb7e Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Wed, 18 Jan 2023 13:46:17 +0300 Subject: [PATCH 04/18] add low, high arguments --- datafusion/common/src/utils.rs | 80 +++++++++++-------- .../src/window/window_frame_state.rs | 16 +++- 2 files changed, 59 insertions(+), 37 deletions(-) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index eb908548d9cb..3d487eecf172 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -63,18 +63,14 @@ fn compare( /// This function implements both bisect_left and bisect_right, having the same /// semantics with the Python Standard Library. To use bisect_left, supply true /// as the template argument. To use bisect_right, supply false as the template argument. +/// It searches `item_columns` between rows `low` and `high`. pub fn bisect( item_columns: &[ArrayRef], target: &[ScalarValue], sort_options: &[SortOptions], + low: usize, + high: usize, ) -> Result { - let low: usize = 0; - let high: usize = item_columns - .get(0) - .ok_or_else(|| { - DataFusionError::Internal("Column array shouldn't be empty".to_string()) - })? - .len(); let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { let cmp = compare(current, target, sort_options)?; Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() }) @@ -108,9 +104,9 @@ where Ok(low) } -/// This function implements linear_search, It starts searching from row at `start` index -/// of `item_columns` until last row of the `item_columns`. It assumes `item_columns` is sorted -/// according to `sort_options` and returns would insertion position of the `target`. +/// This function implements linear_search, It searches `item_columns` between rows `low` and `high`. +/// It assumes `item_columns` is sorted according to `sort_options` +/// and returns insertion position of the `target` in the `item_columns`. /// `SIDE` is `true` means left insertion is applied. /// `SIDE` is `false` means right insertion is applied. pub fn linear_search( @@ -118,13 +114,8 @@ pub fn linear_search( target: &[ScalarValue], sort_options: &[SortOptions], mut low: usize, + high: usize, ) -> Result { - let high: usize = item_columns - .get(0) - .ok_or_else(|| { - DataFusionError::Internal("Column array shouldn't be empty".to_string()) - })? - .len(); let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { let cmp = compare(current, target, sort_options)?; Ok::(if SIDE { cmp.is_lt() } else { cmp.is_le() }) @@ -141,7 +132,7 @@ pub fn linear_search( #[cfg(test)] mod tests { - use arrow::array::Float64Array; + use arrow::array::{Array, Float64Array}; use std::sync::Arc; use crate::from_slice::FromSlice; @@ -151,7 +142,7 @@ mod tests { use super::*; #[test] - fn test_bisect_linear_left_and_right() { + fn test_bisect_linear_left_and_right() -> Result<()> { let arrays: Vec = vec![ Arc::new(Float64Array::from_slice([5.0, 7.0, 8.0, 9., 10.])), Arc::new(Float64Array::from_slice([2.0, 3.0, 3.0, 4.0, 5.0])), @@ -182,15 +173,16 @@ mod tests { nulls_first: true, }, ]; - let res: usize = bisect::(&arrays, &search_tuple, &ords).unwrap(); + let n_row = arrays[0].len(); + let res: usize = bisect::(&arrays, &search_tuple, &ords, 0, n_row)?; assert_eq!(res, 2); - let res: usize = bisect::(&arrays, &search_tuple, &ords).unwrap(); + let res: usize = bisect::(&arrays, &search_tuple, &ords, 0, n_row)?; assert_eq!(res, 3); - let res: usize = linear_search::(&arrays, &search_tuple, &ords, 0).unwrap(); + let res: usize = linear_search::(&arrays, &search_tuple, &ords, 0, n_row)?; assert_eq!(res, 2); - let res: usize = - linear_search::(&arrays, &search_tuple, &ords, 0).unwrap(); + let res: usize = linear_search::(&arrays, &search_tuple, &ords, 0, n_row)?; assert_eq!(res, 3); + Ok(()) } #[test] @@ -227,7 +219,7 @@ mod tests { } #[test] - fn test_bisect_linear_left_and_right_diff_sort() { + fn test_bisect_linear_left_and_right_diff_sort() -> Result<()> { // Descending, left let arrays: Vec = vec![Arc::new(Float64Array::from_slice([ 4.0, 3.0, 2.0, 1.0, 0.0, @@ -237,7 +229,11 @@ mod tests { descending: true, nulls_first: true, }]; - let res: usize = bisect::(&arrays, &search_tuple, &ords).unwrap(); + let res: usize = + bisect::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + assert_eq!(res, 0); + let res: usize = + linear_search::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; assert_eq!(res, 0); // Descending, right @@ -249,7 +245,11 @@ mod tests { descending: true, nulls_first: true, }]; - let res: usize = bisect::(&arrays, &search_tuple, &ords).unwrap(); + let res: usize = + bisect::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + assert_eq!(res, 1); + let res: usize = + linear_search::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; assert_eq!(res, 1); // Ascending, left @@ -260,7 +260,11 @@ mod tests { descending: false, nulls_first: true, }]; - let res: usize = bisect::(&arrays, &search_tuple, &ords).unwrap(); + let res: usize = + bisect::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + assert_eq!(res, 1); + let res: usize = + linear_search::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; assert_eq!(res, 1); // Ascending, right @@ -271,7 +275,11 @@ mod tests { descending: false, nulls_first: true, }]; - let res: usize = bisect::(&arrays, &search_tuple, &ords).unwrap(); + let res: usize = + bisect::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + assert_eq!(res, 2); + let res: usize = + linear_search::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; assert_eq!(res, 2); let arrays: Vec = vec![ @@ -292,17 +300,19 @@ mod tests { nulls_first: true, }, ]; - let res: usize = bisect::(&arrays, &search_tuple, &ords).unwrap(); + let res: usize = + bisect::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; assert_eq!(res, 3); - - let res: usize = bisect::(&arrays, &search_tuple, &ords).unwrap(); - assert_eq!(res, 2); - let res: usize = - linear_search::(&arrays, &search_tuple, &ords, 0).unwrap(); + linear_search::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; assert_eq!(res, 3); - let res: usize = linear_search::(&arrays, &search_tuple, &ords, 0).unwrap(); + let res: usize = + bisect::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + assert_eq!(res, 2); + let res: usize = + linear_search::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; assert_eq!(res, 2); + Ok(()) } } diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 85c8b07e3165..7e31a94f31df 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -186,6 +186,7 @@ impl WindowFrameStateRange { idx, Some(n), last_range, + length, )? } } @@ -199,6 +200,7 @@ impl WindowFrameStateRange { idx, None, last_range, + length, )? } } @@ -209,6 +211,7 @@ impl WindowFrameStateRange { idx, Some(n), last_range, + length, )?, }; let end = match window_frame.end_bound { @@ -219,6 +222,7 @@ impl WindowFrameStateRange { idx, Some(n), last_range, + length, )?, WindowFrameBound::CurrentRow => { if range_columns.is_empty() { @@ -230,6 +234,7 @@ impl WindowFrameStateRange { idx, None, last_range, + length, )? } } @@ -244,6 +249,7 @@ impl WindowFrameStateRange { idx, Some(n), last_range, + length, )? } } @@ -261,6 +267,7 @@ impl WindowFrameStateRange { idx: usize, delta: Option<&ScalarValue>, last_range: &Range, + length: usize, ) -> Result { let current_row_values = range_columns .iter() @@ -301,8 +308,13 @@ impl WindowFrameStateRange { last_range.end }; // `SIDE` true means from left insert, false means right insert - let res = - linear_search::(range_columns, &end_range, sort_options, search_start)?; + let res = linear_search::( + range_columns, + &end_range, + sort_options, + search_start, + length, + )?; Ok(res) } } From 08617399f9a60a8345acbe06c39d26daddfe3e17 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 18 Jan 2023 22:08:31 -0600 Subject: [PATCH 05/18] Go back to old API, improve comments, refactors --- datafusion/common/src/utils.rs | 125 ++++++++++-------- .../physical-expr/src/window/built_in.rs | 7 +- .../src/window/window_frame_state.rs | 16 +-- 3 files changed, 79 insertions(+), 69 deletions(-) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 3d487eecf172..d5c0cfcf3432 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -22,16 +22,16 @@ use arrow::array::ArrayRef; use arrow::compute::SortOptions; use std::cmp::Ordering; -/// Given column vectors, returns row at `idx` -fn get_row_at_idx(item_columns: &[ArrayRef], idx: usize) -> Result> { - item_columns +/// Given column vectors, returns row at `idx`. +fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result> { + columns .iter() .map(|arr| ScalarValue::try_from_array(arr, idx)) - .collect::>>() + .collect() } /// This function compares two tuples depending on the given sort options. -fn compare( +pub fn compare_rows( x: &[ScalarValue], y: &[ScalarValue], sort_options: &[SortOptions], @@ -60,28 +60,35 @@ fn compare( Ok(Ordering::Equal) } -/// This function implements both bisect_left and bisect_right, having the same -/// semantics with the Python Standard Library. To use bisect_left, supply true -/// as the template argument. To use bisect_right, supply false as the template argument. -/// It searches `item_columns` between rows `low` and `high`. +/// This function searches for a tuple of given values (`target`) among the given +/// rows (`item_columns`) using the bisection algorithm. It assumes that `item_columns` +/// is sorted according to `sort_options` and returns the insertion index of `target`. +/// Template argument `SIDE` being `true`/`false` means left/right insertion. pub fn bisect( item_columns: &[ArrayRef], target: &[ScalarValue], sort_options: &[SortOptions], - low: usize, - high: usize, ) -> Result { + let low: usize = 0; + let high: usize = item_columns + .get(0) + .ok_or_else(|| { + DataFusionError::Internal("Column array shouldn't be empty".to_string()) + })? + .len(); let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { - let cmp = compare(current, target, sort_options)?; + let cmp = compare_rows(current, target, sort_options)?; Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() }) }; find_bisect_point(item_columns, target, compare_fn, low, high) } -/// This function searches for a tuple of target values among the given rows using the bisection algorithm. -/// The boolean-valued function `compare_fn` specifies whether we bisect on the left (with return value `false`), -/// or on the right (with return value `true`) as we compare the target value with the current value as we iteratively -/// bisect the input. +/// This function searches for a tuple of given values (`target`) among a slice of +/// the given rows (`item_columns`) using the bisection algorithm. The slice starts +/// at the index `low` and ends at the index `high`. The boolean-valued function +/// `compare_fn` specifies whether we bisect on the left (by returning `false`), +/// or on the right (by returning `true`) when we compare the target value with +/// the current value as we iteratively bisect the input. pub fn find_bisect_point( item_columns: &[ArrayRef], target: &[ScalarValue], @@ -104,22 +111,43 @@ where Ok(low) } -/// This function implements linear_search, It searches `item_columns` between rows `low` and `high`. -/// It assumes `item_columns` is sorted according to `sort_options` -/// and returns insertion position of the `target` in the `item_columns`. -/// `SIDE` is `true` means left insertion is applied. -/// `SIDE` is `false` means right insertion is applied. +/// This function searches for a tuple of given values (`target`) among the given +/// rows (`item_columns`) via a linear scan. It assumes that `item_columns` is sorted +/// according to `sort_options` and returns the insertion index of `target`. +/// Template argument `SIDE` being `true`/`false` means left/right insertion. pub fn linear_search( item_columns: &[ArrayRef], target: &[ScalarValue], sort_options: &[SortOptions], - mut low: usize, - high: usize, ) -> Result { + let low: usize = 0; + let high: usize = item_columns + .get(0) + .ok_or_else(|| { + DataFusionError::Internal("Column array shouldn't be empty".to_string()) + })? + .len(); let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { - let cmp = compare(current, target, sort_options)?; - Ok::(if SIDE { cmp.is_lt() } else { cmp.is_le() }) + let cmp = compare_rows(current, target, sort_options)?; + Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() }) }; + search_in_slice(item_columns, target, compare_fn, low, high) +} + +/// This function searches for a tuple of given values (`target`) among a slice of +/// the given rows (`item_columns`) via a linear scan. The slice starts at the index +/// `low` and ends at the index `high`. The boolean-valued function `compare_fn` +/// specifies the stopping criterion. +pub fn search_in_slice( + item_columns: &[ArrayRef], + target: &[ScalarValue], + compare_fn: F, + mut low: usize, + high: usize, +) -> Result +where + F: Fn(&[ScalarValue], &[ScalarValue]) -> Result, +{ while low < high { let val = get_row_at_idx(item_columns, low)?; if !compare_fn(&val, target)? { @@ -132,7 +160,7 @@ pub fn linear_search( #[cfg(test)] mod tests { - use arrow::array::{Array, Float64Array}; + use arrow::array::Float64Array; use std::sync::Arc; use crate::from_slice::FromSlice; @@ -173,14 +201,13 @@ mod tests { nulls_first: true, }, ]; - let n_row = arrays[0].len(); - let res: usize = bisect::(&arrays, &search_tuple, &ords, 0, n_row)?; + let res = bisect::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 2); - let res: usize = bisect::(&arrays, &search_tuple, &ords, 0, n_row)?; + let res = bisect::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 3); - let res: usize = linear_search::(&arrays, &search_tuple, &ords, 0, n_row)?; + let res = linear_search::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 2); - let res: usize = linear_search::(&arrays, &search_tuple, &ords, 0, n_row)?; + let res = linear_search::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 3); Ok(()) } @@ -229,11 +256,9 @@ mod tests { descending: true, nulls_first: true, }]; - let res: usize = - bisect::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + let res = bisect::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 0); - let res: usize = - linear_search::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + let res = linear_search::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 0); // Descending, right @@ -245,11 +270,9 @@ mod tests { descending: true, nulls_first: true, }]; - let res: usize = - bisect::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + let res = bisect::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 1); - let res: usize = - linear_search::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + let res = linear_search::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 1); // Ascending, left @@ -260,11 +283,9 @@ mod tests { descending: false, nulls_first: true, }]; - let res: usize = - bisect::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + let res = bisect::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 1); - let res: usize = - linear_search::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + let res = linear_search::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 1); // Ascending, right @@ -275,11 +296,9 @@ mod tests { descending: false, nulls_first: true, }]; - let res: usize = - bisect::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + let res = bisect::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 2); - let res: usize = - linear_search::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + let res = linear_search::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 2); let arrays: Vec = vec![ @@ -300,18 +319,14 @@ mod tests { nulls_first: true, }, ]; - let res: usize = - bisect::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + let res = bisect::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 3); - let res: usize = - linear_search::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + let res = linear_search::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 3); - let res: usize = - bisect::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + let res = bisect::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 2); - let res: usize = - linear_search::(&arrays, &search_tuple, &ords, 0, arrays[0].len())?; + let res = linear_search::(&arrays, &search_tuple, &ords)?; assert_eq!(res, 2); Ok(()) } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index bfb8e835b84d..b73e2b8dedac 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -102,13 +102,12 @@ impl WindowExpr for BuiltInWindowExpr { self.order_by.iter().map(|o| o.options).collect(); let mut row_wise_results = vec![]; - let length = batch.num_rows(); let (values, order_bys) = self.get_values_orderbys(batch)?; let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); - let mut range = Range { start: 0, end: 0 }; + let range = Range { start: 0, end: 0 }; // We iterate on each row to calculate window frame range and and window function result - for idx in 0..length { - range = window_frame_ctx.calculate_range( + for idx in 0..num_rows { + let range = window_frame_ctx.calculate_range( &order_bys, &sort_options, num_rows, diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 7e31a94f31df..02e765a01166 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -20,7 +20,7 @@ use arrow::array::ArrayRef; use arrow::compute::kernels::sort::SortOptions; -use datafusion_common::utils::{find_bisect_point, linear_search}; +use datafusion_common::utils::{compare_rows, find_bisect_point, search_in_slice}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::cmp::min; @@ -307,15 +307,11 @@ impl WindowFrameStateRange { } else { last_range.end }; - // `SIDE` true means from left insert, false means right insert - let res = linear_search::( - range_columns, - &end_range, - sort_options, - search_start, - length, - )?; - Ok(res) + let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { + let cmp = compare_rows(current, target, sort_options)?; + Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() }) + }; + search_in_slice(range_columns, &end_range, compare_fn, search_start, length) } } From d1a53ac6a17e66ec13a64e88bbadf2efcf520c24 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 19 Jan 2023 16:29:44 +0300 Subject: [PATCH 06/18] Linear Groups implementation --- datafusion/common/src/utils.rs | 22 +- .../core/src/physical_optimizer/test_utils.rs | 7 +- datafusion/core/tests/sql/window.rs | 35 + .../src/window/window_frame_state.rs | 1123 ++++++++--------- 4 files changed, 565 insertions(+), 622 deletions(-) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index d5c0cfcf3432..42924348b1ba 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -23,7 +23,7 @@ use arrow::compute::SortOptions; use std::cmp::Ordering; /// Given column vectors, returns row at `idx`. -fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result> { +pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result> { columns .iter() .map(|arr| ScalarValue::try_from_array(arr, idx)) @@ -158,6 +158,26 @@ where Ok(low) } +/// This function searches for a tuple of given values (`target`) among a slice of +/// the given rows (`item_columns`) via a linear scan. The slice starts at the index +/// `low` and ends at the index `high`. The boolean-valued function `compare_fn` +/// specifies the stopping criterion. +pub fn search_till_change( + item_columns: &[ArrayRef], + mut low: usize, + high: usize, +) -> Result { + let start_row = get_row_at_idx(item_columns, low)?; + while low < high { + let val = get_row_at_idx(item_columns, low)?; + if !start_row.eq(&val) { + break; + } + low += 1; + } + Ok(low) +} + #[cfg(test)] mod tests { use arrow::array::Float64Array; diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 1404dfa20c30..8689b016b01c 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -118,7 +118,12 @@ impl QueryCase { if error.is_some() { let plan_error = plan.unwrap_err(); let initial = error.unwrap().to_string(); - assert!(plan_error.to_string().contains(initial.as_str())); + assert!( + plan_error.to_string().contains(initial.as_str()), + "plan_error: {:?} doesn't contain message: {:?}", + plan_error, + initial.as_str() + ); } else { assert!(plan.is_ok()) } diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 1167d57a4ffb..3ec65596fcc7 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -2631,3 +2631,38 @@ mod tests { Ok(()) } } + +#[tokio::test] +#[ignore] +async fn window_frame_groups_preceding_following_desc_v2() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + SUM(c9) OVER (ORDER BY c2, c3 GROUPS BETWEEN 7 PRECEDING AND 5 PRECEDING) as a13 + FROM aggregate_test_100 + ORDER BY c9 + LIMIT 5"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + // "+------------+-------------+------------+", + // "| a11 | a12 | a13 |", + // "+------------+-------------+------------+", + // "| 8762794441 | 23265575240 | 4216440507 |", + // "| 9816952706 | 32247727190 | 4216440507 |", + // "| | 16948357027 | 4216440507 |", + // "| 5488534936 | 35475180026 | 4216440507 |", + // "| 4251998668 | 26703253620 | 4216440507 |", + // "+------------+-------------+------------+", + "+------------+-------------+-------------+", + "| a11 | a12 | a13 |", + "+------------+-------------+-------------+", + "| 8762794441 | 23265575240 | 5838012327 |", + "| 9816952706 | 32247727190 | 6763234295 |", + "| | 16948357027 | 7569829821 |", + "| 5488534936 | 35475180026 | 12012095561 |", + "| 4251998668 | 26703253620 | 12027257685 |", + "+------------+-------------+-------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 02e765a01166..a7746505d7e0 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -20,7 +20,9 @@ use arrow::array::ArrayRef; use arrow::compute::kernels::sort::SortOptions; -use datafusion_common::utils::{compare_rows, find_bisect_point, search_in_slice}; +use datafusion_common::utils::{ + compare_rows, find_bisect_point, get_row_at_idx, search_in_slice, search_till_change, +}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::cmp::min; @@ -58,6 +60,7 @@ impl<'a> WindowFrameContext<'a> { WindowFrameUnits::Groups => WindowFrameContext::Groups { window_frame, state: WindowFrameStateGroups::default(), + // state: WindowFrameStateGroups::default(), }, } } @@ -93,7 +96,14 @@ impl<'a> WindowFrameContext<'a> { WindowFrameContext::Groups { window_frame, ref mut state, - } => state.calculate_range(window_frame, range_columns, length, idx), + } => state.calculate_range( + window_frame, + range_columns, + sort_options, + length, + idx, + last_range, + ), } } @@ -341,22 +351,19 @@ impl WindowFrameStateRange { // scan groups of data while processing window frames. #[derive(Debug, Default)] pub struct WindowFrameStateGroups { - current_group_idx: u64, group_start_indices: VecDeque<(Vec, usize)>, - previous_row_values: Option>, - reached_end: bool, - window_frame_end_idx: u64, - window_frame_start_idx: u64, + cur_group_idx: usize, } impl WindowFrameStateGroups { - /// This function calculates beginning/ending indices for the frame of the current row. fn calculate_range( &mut self, window_frame: &Arc, range_columns: &[ArrayRef], + sort_options: &[SortOptions], length: usize, idx: usize, + last_range: &Range, ) -> Result> { if range_columns.is_empty() { return Err(DataFusionError::Execution( @@ -364,656 +371,532 @@ impl WindowFrameStateGroups { )); } let start = match window_frame.start_bound { - // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0, - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => self - .calculate_index_of_group::(range_columns, idx, n, length)?, - WindowFrameBound::CurrentRow => self.calculate_index_of_group::( - range_columns, - idx, - 0, - length, - )?, - WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => self - .calculate_index_of_group::(range_columns, idx, n, length)?, - // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::UInt64(None)) => { - return Err(DataFusionError::Internal(format!( - "Frame start cannot be UNBOUNDED FOLLOWING '{window_frame:?}'" - ))) - } - // ERRONEOUS FRAMES - WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { - return Err(DataFusionError::Internal( - "Groups should be Uint".to_string(), - )) + WindowFrameBound::Preceding(ref n) => { + if n.is_null() { + // UNBOUNDED PRECEDING + 0 + } else { + self.calculate_index_of_row::( + range_columns, + sort_options, + idx, + Some(n), + last_range, + length, + )? + } } - }; - let end = match window_frame.end_bound { - // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => { - return Err(DataFusionError::Internal(format!( - "Frame end cannot be UNBOUNDED PRECEDING '{window_frame:?}'" - ))) + WindowFrameBound::CurrentRow => { + if range_columns.is_empty() { + 0 + } else { + self.calculate_index_of_row::( + range_columns, + sort_options, + idx, + None, + last_range, + length, + )? + } } - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => self - .calculate_index_of_group::(range_columns, idx, n, length)?, - WindowFrameBound::CurrentRow => self - .calculate_index_of_group::( + WindowFrameBound::Following(ref n) => self + .calculate_index_of_row::( range_columns, + sort_options, idx, - 0, + Some(n), + last_range, length, )?, - WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => self - .calculate_index_of_group::( + }; + let end = match window_frame.end_bound { + WindowFrameBound::Preceding(ref n) => self + .calculate_index_of_row::( range_columns, + sort_options, idx, - n, + Some(n), + last_range, length, )?, - // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::UInt64(None)) => length, - // ERRONEOUS FRAMES - WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { - return Err(DataFusionError::Internal( - "Groups should be Uint".to_string(), - )) + WindowFrameBound::CurrentRow => { + if range_columns.is_empty() { + length + } else { + self.calculate_index_of_row::( + range_columns, + sort_options, + idx, + None, + last_range, + length, + )? + } + } + WindowFrameBound::Following(ref n) => { + if n.is_null() { + // UNBOUNDED FOLLOWING + length + } else { + self.calculate_index_of_row::( + range_columns, + sort_options, + idx, + Some(n), + last_range, + length, + )? + } } }; Ok(Range { start, end }) } - /// This function does the heavy lifting when finding group boundaries. It is meant to be - /// called twice, in succession, to get window frame start and end indices (with `BISECT_SIDE` - /// supplied as false and true, respectively). - fn calculate_index_of_group( + /// This function does the heavy lifting when finding range boundaries. It is meant to be + /// called twice, in succession, to get window frame start and end indices (with `SIDE` + /// supplied as true and false, respectively). + fn calculate_index_of_row( &mut self, range_columns: &[ArrayRef], + sort_options: &[SortOptions], idx: usize, - delta: u64, + delta: Option<&ScalarValue>, + last_range: &Range, length: usize, ) -> Result { - let current_row_values = range_columns - .iter() - .map(|col| ScalarValue::try_from_array(col, idx)) - .collect::>>()?; - - if BISECT_SIDE { - // When we call this function to get the window frame start index, it tries to initialize - // the internal grouping state if this is not already done before. This initialization takes - // place only when the window frame start index is greater than or equal to zero. In this - // case, the current row is stored in group_start_indices, with row values as the group - // identifier and row index as the start index of the group. - if !self.initialized() { - self.initialize::(delta, range_columns)?; - } - } else if !self.reached_end { - // When we call this function to get the window frame end index, it extends the window - // frame one by one until the current row's window frame end index is reached by finding - // the next group. - self.extend_window_frame_if_necessary::(range_columns, delta)?; - } - // We keep track of previous row values, so that a group change can be identified. - // If there is a group change, the window frame is advanced and shifted by one group. - let group_change = match &self.previous_row_values { - None => false, - Some(values) => ¤t_row_values != values, - }; - if self.previous_row_values.is_none() || group_change { - self.previous_row_values = Some(current_row_values); - } - if group_change { - self.current_group_idx += 1; - self.advance_one_group::(range_columns)?; - self.shift_one_group::(delta); - } - Ok(if self.group_start_indices.is_empty() { - if self.reached_end { - length - } else { - 0 - } - } else if BISECT_SIDE { - match self.group_start_indices.get(0) { - Some(&(_, idx)) => idx, - None => 0, - } + let search_start = if SIDE { + last_range.start } else { - match (self.reached_end, self.group_start_indices.back()) { - (false, Some(&(_, idx))) => idx, - _ => length, + last_range.end + }; + let delta = if let Some(delta) = delta { + match delta { + ScalarValue::UInt64(Some(val)) => Ok(*val as usize), + _ => Err(DataFusionError::Execution("expects uint64".to_string())), } - }) - } - - fn extend_window_frame_if_necessary( - &mut self, - range_columns: &[ArrayRef], - delta: u64, - ) -> Result<()> { - let current_window_frame_end_idx = if !SEARCH_SIDE { - self.current_group_idx + delta + 1 - } else if self.current_group_idx >= delta { - self.current_group_idx - delta + 1 } else { - 0 - }; - if current_window_frame_end_idx == 0 { - // the end index of the window frame is still before the first index - return Ok(()); - } - if self.group_start_indices.is_empty() { - self.initialize_window_frame_start(range_columns)?; + Ok(0) + }?; + let mut change_idx = search_start; + let last_group = self.group_start_indices.back(); + if let Some((_last_group, bound_idx)) = last_group { + change_idx = *bound_idx; } - while !self.reached_end - && self.window_frame_end_idx <= current_window_frame_end_idx - { - self.advance_one_group::(range_columns)?; - } - Ok(()) - } - fn initialize( - &mut self, - delta: u64, - range_columns: &[ArrayRef], - ) -> Result<()> { - if !SEARCH_SIDE { - self.window_frame_start_idx = self.current_group_idx + delta; - self.initialize_window_frame_start(range_columns) - } else if self.current_group_idx >= delta { - self.window_frame_start_idx = self.current_group_idx - delta; - self.initialize_window_frame_start(range_columns) - } else { - Ok(()) + while idx > change_idx { + let group_row = get_row_at_idx(range_columns, change_idx)?; + change_idx = search_till_change(range_columns, change_idx, length)?; + self.group_start_indices.push_back((group_row, change_idx)); } - } - fn initialize_window_frame_start( - &mut self, - range_columns: &[ArrayRef], - ) -> Result<()> { - let mut group_values = range_columns - .iter() - .map(|col| ScalarValue::try_from_array(col, 0)) - .collect::>>()?; - let mut start_idx: usize = 0; - for _ in 0..self.window_frame_start_idx { - let next_group_and_start_index = - WindowFrameStateGroups::find_next_group_and_start_index( - range_columns, - &group_values, - start_idx, - )?; - if let Some(entry) = next_group_and_start_index { - (group_values, start_idx) = entry; + while self.cur_group_idx < self.group_start_indices.len() { + if idx >= self.group_start_indices[self.cur_group_idx].1 { + self.cur_group_idx += 1; } else { - // not enough groups to generate a window frame - self.window_frame_end_idx = self.window_frame_start_idx; - self.reached_end = true; - return Ok(()); + break; } } - self.group_start_indices - .push_back((group_values, start_idx)); - self.window_frame_end_idx = self.window_frame_start_idx + 1; - Ok(()) - } - - fn initialized(&self) -> bool { - self.reached_end || !self.group_start_indices.is_empty() - } - - /// This function advances the window frame by one group. - fn advance_one_group( - &mut self, - range_columns: &[ArrayRef], - ) -> Result<()> { - let last_group_values = self.group_start_indices.back(); - let last_group_values = if let Some(values) = last_group_values { - values + let group_idx = if PRECEDING { + if self.cur_group_idx > delta { + self.cur_group_idx - delta + } else { + 0 + } } else { - return Ok(()); + self.cur_group_idx + delta }; - let next_group_and_start_index = - WindowFrameStateGroups::find_next_group_and_start_index( - range_columns, - &last_group_values.0, - last_group_values.1, - )?; - if let Some(entry) = next_group_and_start_index { - self.group_start_indices.push_back(entry); - self.window_frame_end_idx += 1; - } else { - // not enough groups to proceed - self.reached_end = true; - } - Ok(()) - } - /// This function drops the oldest group from the window frame. - fn shift_one_group(&mut self, delta: u64) { - let current_window_frame_start_idx = if !SEARCH_SIDE { - self.current_group_idx + delta - } else if self.current_group_idx >= delta { - self.current_group_idx - delta - } else { - 0 - }; - if current_window_frame_start_idx > self.window_frame_start_idx { - self.group_start_indices.pop_front(); - self.window_frame_start_idx += 1; + while self.group_start_indices.len() <= group_idx && change_idx < length { + // println!("change idx: {:?}, idx: {:?}, self.cur_group_idx: {:?}", change_idx, idx, self.cur_group_idx); + let group_row = get_row_at_idx(range_columns, change_idx)?; + change_idx = search_till_change(range_columns, change_idx, length)?; + self.group_start_indices.push_back((group_row, change_idx)); } - } - /// This function finds the next group and its start index for a given group and start index. - /// It utilizes an exponentially growing step size to find the group boundary. - // TODO: For small group sizes, proceeding one-by-one to find the group change can be more efficient. - // Statistics about previous group sizes can be used to choose one-by-one vs. exponentially growing, - // or even to set the base step_size when exponentially growing. We can also create a benchmark - // implementation to get insights about the crossover point. - fn find_next_group_and_start_index( - range_columns: &[ArrayRef], - current_row_values: &[ScalarValue], - idx: usize, - ) -> Result, usize)>> { - let mut step_size: usize = 1; - let data_size: usize = range_columns - .get(0) - .ok_or_else(|| { - DataFusionError::Internal("Column array shouldn't be empty".to_string()) - })? - .len(); - let mut low = idx; - let mut high = idx + step_size; - while high < data_size { - let val = range_columns - .iter() - .map(|arr| ScalarValue::try_from_array(arr, high)) - .collect::>>()?; - if val == current_row_values { - low = high; - step_size *= 2; - high += step_size; + let res = if SIDE { + let group_idx = min(group_idx, self.group_start_indices.len()); + if group_idx > 0 { + self.group_start_indices[group_idx - 1].1 } else { - break; - } - } - low = find_bisect_point( - range_columns, - current_row_values, - |current, to_compare| Ok(current == to_compare), - low, - min(high, data_size), - )?; - if low == data_size { - return Ok(None); - } - let val = range_columns - .iter() - .map(|arr| ScalarValue::try_from_array(arr, low)) - .collect::>>()?; - Ok(Some((val, low))) - } -} - -#[cfg(test)] -mod tests { - use arrow::array::Float64Array; - use datafusion_common::ScalarValue; - use std::sync::Arc; - - use crate::from_slice::FromSlice; - - use super::*; - - struct TestData { - arrays: Vec, - group_indices: [usize; 6], - num_groups: usize, - num_rows: usize, - next_group_indices: [usize; 5], - } - - fn test_data() -> TestData { - let num_groups: usize = 5; - let num_rows: usize = 6; - let group_indices = [0, 1, 2, 2, 4, 5]; - let next_group_indices = [1, 2, 4, 4, 5]; - - let arrays: Vec = vec![ - Arc::new(Float64Array::from_slice([5.0, 7.0, 8.0, 8., 9., 10.])), - Arc::new(Float64Array::from_slice([2.0, 3.0, 3.0, 3., 4.0, 5.0])), - Arc::new(Float64Array::from_slice([5.0, 7.0, 8.0, 8., 10., 11.0])), - Arc::new(Float64Array::from_slice([15.0, 13.0, 8.0, 8., 5., 0.0])), - ]; - TestData { - arrays, - group_indices, - num_groups, - num_rows, - next_group_indices, - } - } - - #[test] - fn test_find_next_group_and_start_index() { - let test_data = test_data(); - for (current_idx, next_idx) in test_data.next_group_indices.iter().enumerate() { - let current_row_values = test_data - .arrays - .iter() - .map(|col| ScalarValue::try_from_array(col, current_idx)) - .collect::>>() - .unwrap(); - let next_row_values = test_data - .arrays - .iter() - .map(|col| ScalarValue::try_from_array(col, *next_idx)) - .collect::>>() - .unwrap(); - let res = WindowFrameStateGroups::find_next_group_and_start_index( - &test_data.arrays, - ¤t_row_values, - current_idx, - ) - .unwrap(); - assert_eq!(res, Some((next_row_values, *next_idx))); - } - let current_idx = test_data.num_rows - 1; - let current_row_values = test_data - .arrays - .iter() - .map(|col| ScalarValue::try_from_array(col, current_idx)) - .collect::>>() - .unwrap(); - let res = WindowFrameStateGroups::find_next_group_and_start_index( - &test_data.arrays, - ¤t_row_values, - current_idx, - ) - .unwrap(); - assert_eq!(res, None); - } - - #[test] - fn test_window_frame_groups_preceding_delta_greater_than_partition_size() { - const START: bool = true; - const END: bool = false; - const PRECEDING: bool = true; - const DELTA: u64 = 10; - - let test_data = test_data(); - let mut window_frame_groups = WindowFrameStateGroups::default(); - window_frame_groups - .initialize::(DELTA, &test_data.arrays) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 0); - assert!(!window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); - - window_frame_groups - .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 0); - assert!(!window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); - - for idx in 0..test_data.num_rows { - let start = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(start, 0); - let end = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(end, 0); - } - } - - #[test] - fn test_window_frame_groups_following_delta_greater_than_partition_size() { - const START: bool = true; - const END: bool = false; - const FOLLOWING: bool = false; - const DELTA: u64 = 10; - - let test_data = test_data(); - let mut window_frame_groups = WindowFrameStateGroups::default(); - window_frame_groups - .initialize::(DELTA, &test_data.arrays) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, DELTA); - assert_eq!(window_frame_groups.window_frame_end_idx, DELTA); - assert!(window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); - - window_frame_groups - .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, DELTA); - assert_eq!(window_frame_groups.window_frame_end_idx, DELTA); - assert!(window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); - - for idx in 0..test_data.num_rows { - let start = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(start, test_data.num_rows); - let end = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(end, test_data.num_rows); - } - } - - #[test] - fn test_window_frame_groups_preceding_and_following_delta_greater_than_partition_size( - ) { - const START: bool = true; - const END: bool = false; - const FOLLOWING: bool = false; - const PRECEDING: bool = true; - const DELTA: u64 = 10; - - let test_data = test_data(); - let mut window_frame_groups = WindowFrameStateGroups::default(); - window_frame_groups - .initialize::(DELTA, &test_data.arrays) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 0); - assert!(!window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); - - window_frame_groups - .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!( - window_frame_groups.window_frame_end_idx, - test_data.num_groups as u64 - ); - assert!(window_frame_groups.reached_end); - assert_eq!( - window_frame_groups.group_start_indices.len(), - test_data.num_groups - ); - - for idx in 0..test_data.num_rows { - let start = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(start, 0); - let end = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(end, test_data.num_rows); - } - } - - #[test] - fn test_window_frame_groups_preceding_and_following_1() { - const START: bool = true; - const END: bool = false; - const FOLLOWING: bool = false; - const PRECEDING: bool = true; - const DELTA: u64 = 1; - - let test_data = test_data(); - let mut window_frame_groups = WindowFrameStateGroups::default(); - window_frame_groups - .initialize::(DELTA, &test_data.arrays) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 0); - assert!(!window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); - - window_frame_groups - .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 2 * DELTA + 1); - assert!(!window_frame_groups.reached_end); - assert_eq!( - window_frame_groups.group_start_indices.len(), - 2 * DELTA as usize + 1 - ); - - for idx in 0..test_data.num_rows { - let start_idx = if idx < DELTA as usize { 0 - } else { - test_data.group_indices[idx] - DELTA as usize - }; - let start = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(start, test_data.group_indices[start_idx]); - let mut end_idx = if idx >= test_data.num_groups { - test_data.num_rows - } else { - test_data.next_group_indices[idx] - }; - for _ in 0..DELTA { - end_idx = if end_idx >= test_data.num_groups { - test_data.num_rows - } else { - test_data.next_group_indices[end_idx] - }; } - let end = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(end, end_idx); - } - } - - #[test] - fn test_window_frame_groups_preceding_1_and_unbounded_following() { - const START: bool = true; - const PRECEDING: bool = true; - const DELTA: u64 = 1; - - let test_data = test_data(); - let mut window_frame_groups = WindowFrameStateGroups::default(); - window_frame_groups - .initialize::(DELTA, &test_data.arrays) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 0); - assert!(!window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 0); - - for idx in 0..test_data.num_rows { - let start_idx = if idx < DELTA as usize { - 0 + } else { + if PRECEDING { + if self.cur_group_idx >= delta { + let group_idx = self.cur_group_idx - delta; + self.group_start_indices[group_idx].1 + } else { + 0 + } } else { - test_data.group_indices[idx] - DELTA as usize - }; - let start = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(start, test_data.group_indices[start_idx]); - } - } - - #[test] - fn test_window_frame_groups_current_and_unbounded_following() { - const START: bool = true; - const PRECEDING: bool = true; - const DELTA: u64 = 0; - - let test_data = test_data(); - let mut window_frame_groups = WindowFrameStateGroups::default(); - window_frame_groups - .initialize::(DELTA, &test_data.arrays) - .unwrap(); - assert_eq!(window_frame_groups.window_frame_start_idx, 0); - assert_eq!(window_frame_groups.window_frame_end_idx, 1); - assert!(!window_frame_groups.reached_end); - assert_eq!(window_frame_groups.group_start_indices.len(), 1); - - for idx in 0..test_data.num_rows { - let start = window_frame_groups - .calculate_index_of_group::( - &test_data.arrays, - idx, - DELTA, - test_data.num_rows, - ) - .unwrap(); - assert_eq!(start, test_data.group_indices[idx]); - } + let group_idx = min( + self.cur_group_idx + delta, + self.group_start_indices.len() - 1, + ); + self.group_start_indices[group_idx].1 + } + }; + Ok(res) } } + +// #[cfg(test)] +// mod tests { +// use arrow::array::Float64Array; +// use datafusion_common::ScalarValue; +// use std::sync::Arc; +// +// use crate::from_slice::FromSlice; +// +// use super::*; +// +// struct TestData { +// arrays: Vec, +// group_indices: [usize; 6], +// num_groups: usize, +// num_rows: usize, +// next_group_indices: [usize; 5], +// } +// +// fn test_data() -> TestData { +// let num_groups: usize = 5; +// let num_rows: usize = 6; +// let group_indices = [0, 1, 2, 2, 4, 5]; +// let next_group_indices = [1, 2, 4, 4, 5]; +// +// let arrays: Vec = vec![ +// Arc::new(Float64Array::from_slice([5.0, 7.0, 8.0, 8., 9., 10.])), +// Arc::new(Float64Array::from_slice([2.0, 3.0, 3.0, 3., 4.0, 5.0])), +// Arc::new(Float64Array::from_slice([5.0, 7.0, 8.0, 8., 10., 11.0])), +// Arc::new(Float64Array::from_slice([15.0, 13.0, 8.0, 8., 5., 0.0])), +// ]; +// TestData { +// arrays, +// group_indices, +// num_groups, +// num_rows, +// next_group_indices, +// } +// } +// +// #[test] +// fn test_find_next_group_and_start_index() { +// let test_data = test_data(); +// for (current_idx, next_idx) in test_data.next_group_indices.iter().enumerate() { +// let current_row_values = test_data +// .arrays +// .iter() +// .map(|col| ScalarValue::try_from_array(col, current_idx)) +// .collect::>>() +// .unwrap(); +// let next_row_values = test_data +// .arrays +// .iter() +// .map(|col| ScalarValue::try_from_array(col, *next_idx)) +// .collect::>>() +// .unwrap(); +// let res = WindowFrameStateGroups::find_next_group_and_start_index( +// &test_data.arrays, +// ¤t_row_values, +// current_idx, +// ) +// .unwrap(); +// assert_eq!(res, Some((next_row_values, *next_idx))); +// } +// let current_idx = test_data.num_rows - 1; +// let current_row_values = test_data +// .arrays +// .iter() +// .map(|col| ScalarValue::try_from_array(col, current_idx)) +// .collect::>>() +// .unwrap(); +// let res = WindowFrameStateGroups::find_next_group_and_start_index( +// &test_data.arrays, +// ¤t_row_values, +// current_idx, +// ) +// .unwrap(); +// assert_eq!(res, None); +// } +// +// #[test] +// fn test_window_frame_groups_preceding_delta_greater_than_partition_size() { +// const START: bool = true; +// const END: bool = false; +// const PRECEDING: bool = true; +// const DELTA: u64 = 10; +// +// let test_data = test_data(); +// let mut window_frame_groups = WindowFrameStateGroups::default(); +// window_frame_groups +// .initialize::(DELTA, &test_data.arrays) +// .unwrap(); +// assert_eq!(window_frame_groups.window_frame_start_idx, 0); +// assert_eq!(window_frame_groups.window_frame_end_idx, 0); +// assert!(!window_frame_groups.reached_end); +// assert_eq!(window_frame_groups.group_start_indices.len(), 0); +// +// window_frame_groups +// .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) +// .unwrap(); +// assert_eq!(window_frame_groups.window_frame_start_idx, 0); +// assert_eq!(window_frame_groups.window_frame_end_idx, 0); +// assert!(!window_frame_groups.reached_end); +// assert_eq!(window_frame_groups.group_start_indices.len(), 0); +// +// for idx in 0..test_data.num_rows { +// let start = window_frame_groups +// .calculate_index_of_group::( +// &test_data.arrays, +// idx, +// DELTA, +// test_data.num_rows, +// ) +// .unwrap(); +// assert_eq!(start, 0); +// let end = window_frame_groups +// .calculate_index_of_group::( +// &test_data.arrays, +// idx, +// DELTA, +// test_data.num_rows, +// ) +// .unwrap(); +// assert_eq!(end, 0); +// } +// } +// +// #[test] +// fn test_window_frame_groups_following_delta_greater_than_partition_size() { +// const START: bool = true; +// const END: bool = false; +// const FOLLOWING: bool = false; +// const DELTA: u64 = 10; +// +// let test_data = test_data(); +// let mut window_frame_groups = WindowFrameStateGroups::default(); +// window_frame_groups +// .initialize::(DELTA, &test_data.arrays) +// .unwrap(); +// assert_eq!(window_frame_groups.window_frame_start_idx, DELTA); +// assert_eq!(window_frame_groups.window_frame_end_idx, DELTA); +// assert!(window_frame_groups.reached_end); +// assert_eq!(window_frame_groups.group_start_indices.len(), 0); +// +// window_frame_groups +// .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) +// .unwrap(); +// assert_eq!(window_frame_groups.window_frame_start_idx, DELTA); +// assert_eq!(window_frame_groups.window_frame_end_idx, DELTA); +// assert!(window_frame_groups.reached_end); +// assert_eq!(window_frame_groups.group_start_indices.len(), 0); +// +// for idx in 0..test_data.num_rows { +// let start = window_frame_groups +// .calculate_index_of_group::( +// &test_data.arrays, +// idx, +// DELTA, +// test_data.num_rows, +// ) +// .unwrap(); +// assert_eq!(start, test_data.num_rows); +// let end = window_frame_groups +// .calculate_index_of_group::( +// &test_data.arrays, +// idx, +// DELTA, +// test_data.num_rows, +// ) +// .unwrap(); +// assert_eq!(end, test_data.num_rows); +// } +// } +// +// #[test] +// fn test_window_frame_groups_preceding_and_following_delta_greater_than_partition_size( +// ) { +// const START: bool = true; +// const END: bool = false; +// const FOLLOWING: bool = false; +// const PRECEDING: bool = true; +// const DELTA: u64 = 10; +// +// let test_data = test_data(); +// let mut window_frame_groups = WindowFrameStateGroups::default(); +// window_frame_groups +// .initialize::(DELTA, &test_data.arrays) +// .unwrap(); +// assert_eq!(window_frame_groups.window_frame_start_idx, 0); +// assert_eq!(window_frame_groups.window_frame_end_idx, 0); +// assert!(!window_frame_groups.reached_end); +// assert_eq!(window_frame_groups.group_start_indices.len(), 0); +// +// window_frame_groups +// .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) +// .unwrap(); +// assert_eq!(window_frame_groups.window_frame_start_idx, 0); +// assert_eq!( +// window_frame_groups.window_frame_end_idx, +// test_data.num_groups as u64 +// ); +// assert!(window_frame_groups.reached_end); +// assert_eq!( +// window_frame_groups.group_start_indices.len(), +// test_data.num_groups +// ); +// +// for idx in 0..test_data.num_rows { +// let start = window_frame_groups +// .calculate_index_of_group::( +// &test_data.arrays, +// idx, +// DELTA, +// test_data.num_rows, +// ) +// .unwrap(); +// assert_eq!(start, 0); +// let end = window_frame_groups +// .calculate_index_of_group::( +// &test_data.arrays, +// idx, +// DELTA, +// test_data.num_rows, +// ) +// .unwrap(); +// assert_eq!(end, test_data.num_rows); +// } +// } +// +// #[test] +// fn test_window_frame_groups_preceding_and_following_1() { +// const START: bool = true; +// const END: bool = false; +// const FOLLOWING: bool = false; +// const PRECEDING: bool = true; +// const DELTA: u64 = 1; +// +// let test_data = test_data(); +// let mut window_frame_groups = WindowFrameStateGroups::default(); +// window_frame_groups +// .initialize::(DELTA, &test_data.arrays) +// .unwrap(); +// assert_eq!(window_frame_groups.window_frame_start_idx, 0); +// assert_eq!(window_frame_groups.window_frame_end_idx, 0); +// assert!(!window_frame_groups.reached_end); +// assert_eq!(window_frame_groups.group_start_indices.len(), 0); +// +// window_frame_groups +// .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) +// .unwrap(); +// assert_eq!(window_frame_groups.window_frame_start_idx, 0); +// assert_eq!(window_frame_groups.window_frame_end_idx, 2 * DELTA + 1); +// assert!(!window_frame_groups.reached_end); +// assert_eq!( +// window_frame_groups.group_start_indices.len(), +// 2 * DELTA as usize + 1 +// ); +// +// for idx in 0..test_data.num_rows { +// let start_idx = if idx < DELTA as usize { +// 0 +// } else { +// test_data.group_indices[idx] - DELTA as usize +// }; +// let start = window_frame_groups +// .calculate_index_of_group::( +// &test_data.arrays, +// idx, +// DELTA, +// test_data.num_rows, +// ) +// .unwrap(); +// assert_eq!(start, test_data.group_indices[start_idx]); +// let mut end_idx = if idx >= test_data.num_groups { +// test_data.num_rows +// } else { +// test_data.next_group_indices[idx] +// }; +// for _ in 0..DELTA { +// end_idx = if end_idx >= test_data.num_groups { +// test_data.num_rows +// } else { +// test_data.next_group_indices[end_idx] +// }; +// } +// let end = window_frame_groups +// .calculate_index_of_group::( +// &test_data.arrays, +// idx, +// DELTA, +// test_data.num_rows, +// ) +// .unwrap(); +// assert_eq!(end, end_idx); +// } +// } +// +// #[test] +// fn test_window_frame_groups_preceding_1_and_unbounded_following() { +// const START: bool = true; +// const PRECEDING: bool = true; +// const DELTA: u64 = 1; +// +// let test_data = test_data(); +// let mut window_frame_groups = WindowFrameStateGroups::default(); +// window_frame_groups +// .initialize::(DELTA, &test_data.arrays) +// .unwrap(); +// assert_eq!(window_frame_groups.window_frame_start_idx, 0); +// assert_eq!(window_frame_groups.window_frame_end_idx, 0); +// assert!(!window_frame_groups.reached_end); +// assert_eq!(window_frame_groups.group_start_indices.len(), 0); +// +// for idx in 0..test_data.num_rows { +// let start_idx = if idx < DELTA as usize { +// 0 +// } else { +// test_data.group_indices[idx] - DELTA as usize +// }; +// let start = window_frame_groups +// .calculate_index_of_group::( +// &test_data.arrays, +// idx, +// DELTA, +// test_data.num_rows, +// ) +// .unwrap(); +// assert_eq!(start, test_data.group_indices[start_idx]); +// } +// } +// +// #[test] +// fn test_window_frame_groups_current_and_unbounded_following() { +// const START: bool = true; +// const PRECEDING: bool = true; +// const DELTA: u64 = 0; +// +// let test_data = test_data(); +// let mut window_frame_groups = WindowFrameStateGroups::default(); +// window_frame_groups +// .initialize::(DELTA, &test_data.arrays) +// .unwrap(); +// assert_eq!(window_frame_groups.window_frame_start_idx, 0); +// assert_eq!(window_frame_groups.window_frame_end_idx, 1); +// assert!(!window_frame_groups.reached_end); +// assert_eq!(window_frame_groups.group_start_indices.len(), 1); +// +// for idx in 0..test_data.num_rows { +// let start = window_frame_groups +// .calculate_index_of_group::( +// &test_data.arrays, +// idx, +// DELTA, +// test_data.num_rows, +// ) +// .unwrap(); +// assert_eq!(start, test_data.group_indices[idx]); +// } +// } +// } From 46210c4007eb364e1462d328f1580cf881e1c769 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Wed, 25 Jan 2023 15:50:09 +0300 Subject: [PATCH 07/18] Resolve linter errors --- .../core/src/physical_optimizer/test_utils.rs | 7 +--- datafusion/core/tests/sql/window.rs | 35 ------------------- .../src/window/window_frame_state.rs | 33 +++++++++-------- 3 files changed, 20 insertions(+), 55 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 8689b016b01c..1404dfa20c30 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -118,12 +118,7 @@ impl QueryCase { if error.is_some() { let plan_error = plan.unwrap_err(); let initial = error.unwrap().to_string(); - assert!( - plan_error.to_string().contains(initial.as_str()), - "plan_error: {:?} doesn't contain message: {:?}", - plan_error, - initial.as_str() - ); + assert!(plan_error.to_string().contains(initial.as_str())); } else { assert!(plan.is_ok()) } diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 3ec65596fcc7..1167d57a4ffb 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -2631,38 +2631,3 @@ mod tests { Ok(()) } } - -#[tokio::test] -#[ignore] -async fn window_frame_groups_preceding_following_desc_v2() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - let sql = "SELECT - SUM(c9) OVER (ORDER BY c2, c3 GROUPS BETWEEN 7 PRECEDING AND 5 PRECEDING) as a13 - FROM aggregate_test_100 - ORDER BY c9 - LIMIT 5"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - // "+------------+-------------+------------+", - // "| a11 | a12 | a13 |", - // "+------------+-------------+------------+", - // "| 8762794441 | 23265575240 | 4216440507 |", - // "| 9816952706 | 32247727190 | 4216440507 |", - // "| | 16948357027 | 4216440507 |", - // "| 5488534936 | 35475180026 | 4216440507 |", - // "| 4251998668 | 26703253620 | 4216440507 |", - // "+------------+-------------+------------+", - "+------------+-------------+-------------+", - "| a11 | a12 | a13 |", - "+------------+-------------+-------------+", - "| 8762794441 | 23265575240 | 5838012327 |", - "| 9816952706 | 32247727190 | 6763234295 |", - "| | 16948357027 | 7569829821 |", - "| 5488534936 | 35475180026 | 12012095561 |", - "| 4251998668 | 26703253620 | 12027257685 |", - "+------------+-------------+-------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index a7746505d7e0..0cc1326f304b 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -21,7 +21,7 @@ use arrow::array::ArrayRef; use arrow::compute::kernels::sort::SortOptions; use datafusion_common::utils::{ - compare_rows, find_bisect_point, get_row_at_idx, search_in_slice, search_till_change, + compare_rows, get_row_at_idx, search_in_slice, search_till_change, }; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; @@ -352,6 +352,7 @@ impl WindowFrameStateRange { #[derive(Debug, Default)] pub struct WindowFrameStateGroups { group_start_indices: VecDeque<(Vec, usize)>, + // Keeps the groups index that row index belongs cur_group_idx: usize, } @@ -459,7 +460,7 @@ impl WindowFrameStateGroups { fn calculate_index_of_row( &mut self, range_columns: &[ArrayRef], - sort_options: &[SortOptions], + _sort_options: &[SortOptions], idx: usize, delta: Option<&ScalarValue>, last_range: &Range, @@ -481,15 +482,18 @@ impl WindowFrameStateGroups { let mut change_idx = search_start; let last_group = self.group_start_indices.back(); if let Some((_last_group, bound_idx)) = last_group { + // Start searching from change point from last boundary change_idx = *bound_idx; } + // Progress groups until idx is inside a group while idx > change_idx { let group_row = get_row_at_idx(range_columns, change_idx)?; change_idx = search_till_change(range_columns, change_idx, length)?; self.group_start_indices.push_back((group_row, change_idx)); } + // Update the group index `idx` belongs. while self.cur_group_idx < self.group_start_indices.len() { if idx >= self.group_start_indices[self.cur_group_idx].1 { self.cur_group_idx += 1; @@ -507,37 +511,38 @@ impl WindowFrameStateGroups { self.cur_group_idx + delta }; + // Expand group_start_indices until it includes at least group_idx while self.group_start_indices.len() <= group_idx && change_idx < length { // println!("change idx: {:?}, idx: {:?}, self.cur_group_idx: {:?}", change_idx, idx, self.cur_group_idx); let group_row = get_row_at_idx(range_columns, change_idx)?; change_idx = search_till_change(range_columns, change_idx, length)?; self.group_start_indices.push_back((group_row, change_idx)); } - - let res = if SIDE { - let group_idx = min(group_idx, self.group_start_indices.len()); - if group_idx > 0 { - self.group_start_indices[group_idx - 1].1 - } else { - 0 + Ok(match (SIDE, PRECEDING) { + (true, _) => { + let group_idx = min(group_idx, self.group_start_indices.len()); + if group_idx > 0 { + self.group_start_indices[group_idx - 1].1 + } else { + 0 + } } - } else { - if PRECEDING { + (false, true) => { if self.cur_group_idx >= delta { let group_idx = self.cur_group_idx - delta; self.group_start_indices[group_idx].1 } else { 0 } - } else { + } + (false, false) => { let group_idx = min( self.cur_group_idx + delta, self.group_start_indices.len() - 1, ); self.group_start_indices[group_idx].1 } - }; - Ok(res) + }) } } From 2c9a75629c33ed338f4b74839b07144c96a2a774 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 6 Feb 2023 11:01:57 +0300 Subject: [PATCH 08/18] remove old unit tests --- .../src/window/window_frame_state.rs | 361 ------------------ 1 file changed, 361 deletions(-) diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 2f5d0e50926d..33850742e4d3 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -60,7 +60,6 @@ impl<'a> WindowFrameContext<'a> { WindowFrameUnits::Groups => WindowFrameContext::Groups { window_frame, state: WindowFrameStateGroups::default(), - // state: WindowFrameStateGroups::default(), }, } } @@ -542,363 +541,3 @@ impl WindowFrameStateGroups { }) } } - -// #[cfg(test)] -// mod tests { -// use arrow::array::Float64Array; -// use datafusion_common::ScalarValue; -// use std::sync::Arc; -// -// use crate::from_slice::FromSlice; -// -// use super::*; -// -// struct TestData { -// arrays: Vec, -// group_indices: [usize; 6], -// num_groups: usize, -// num_rows: usize, -// next_group_indices: [usize; 5], -// } -// -// fn test_data() -> TestData { -// let num_groups: usize = 5; -// let num_rows: usize = 6; -// let group_indices = [0, 1, 2, 2, 4, 5]; -// let next_group_indices = [1, 2, 4, 4, 5]; -// -// let arrays: Vec = vec![ -// Arc::new(Float64Array::from_slice([5.0, 7.0, 8.0, 8., 9., 10.])), -// Arc::new(Float64Array::from_slice([2.0, 3.0, 3.0, 3., 4.0, 5.0])), -// Arc::new(Float64Array::from_slice([5.0, 7.0, 8.0, 8., 10., 11.0])), -// Arc::new(Float64Array::from_slice([15.0, 13.0, 8.0, 8., 5., 0.0])), -// ]; -// TestData { -// arrays, -// group_indices, -// num_groups, -// num_rows, -// next_group_indices, -// } -// } -// -// #[test] -// fn test_find_next_group_and_start_index() { -// let test_data = test_data(); -// for (current_idx, next_idx) in test_data.next_group_indices.iter().enumerate() { -// let current_row_values = test_data -// .arrays -// .iter() -// .map(|col| ScalarValue::try_from_array(col, current_idx)) -// .collect::>>() -// .unwrap(); -// let next_row_values = test_data -// .arrays -// .iter() -// .map(|col| ScalarValue::try_from_array(col, *next_idx)) -// .collect::>>() -// .unwrap(); -// let res = WindowFrameStateGroups::find_next_group_and_start_index( -// &test_data.arrays, -// ¤t_row_values, -// current_idx, -// ) -// .unwrap(); -// assert_eq!(res, Some((next_row_values, *next_idx))); -// } -// let current_idx = test_data.num_rows - 1; -// let current_row_values = test_data -// .arrays -// .iter() -// .map(|col| ScalarValue::try_from_array(col, current_idx)) -// .collect::>>() -// .unwrap(); -// let res = WindowFrameStateGroups::find_next_group_and_start_index( -// &test_data.arrays, -// ¤t_row_values, -// current_idx, -// ) -// .unwrap(); -// assert_eq!(res, None); -// } -// -// #[test] -// fn test_window_frame_groups_preceding_delta_greater_than_partition_size() { -// const START: bool = true; -// const END: bool = false; -// const PRECEDING: bool = true; -// const DELTA: u64 = 10; -// -// let test_data = test_data(); -// let mut window_frame_groups = WindowFrameStateGroups::default(); -// window_frame_groups -// .initialize::(DELTA, &test_data.arrays) -// .unwrap(); -// assert_eq!(window_frame_groups.window_frame_start_idx, 0); -// assert_eq!(window_frame_groups.window_frame_end_idx, 0); -// assert!(!window_frame_groups.reached_end); -// assert_eq!(window_frame_groups.group_start_indices.len(), 0); -// -// window_frame_groups -// .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) -// .unwrap(); -// assert_eq!(window_frame_groups.window_frame_start_idx, 0); -// assert_eq!(window_frame_groups.window_frame_end_idx, 0); -// assert!(!window_frame_groups.reached_end); -// assert_eq!(window_frame_groups.group_start_indices.len(), 0); -// -// for idx in 0..test_data.num_rows { -// let start = window_frame_groups -// .calculate_index_of_group::( -// &test_data.arrays, -// idx, -// DELTA, -// test_data.num_rows, -// ) -// .unwrap(); -// assert_eq!(start, 0); -// let end = window_frame_groups -// .calculate_index_of_group::( -// &test_data.arrays, -// idx, -// DELTA, -// test_data.num_rows, -// ) -// .unwrap(); -// assert_eq!(end, 0); -// } -// } -// -// #[test] -// fn test_window_frame_groups_following_delta_greater_than_partition_size() { -// const START: bool = true; -// const END: bool = false; -// const FOLLOWING: bool = false; -// const DELTA: u64 = 10; -// -// let test_data = test_data(); -// let mut window_frame_groups = WindowFrameStateGroups::default(); -// window_frame_groups -// .initialize::(DELTA, &test_data.arrays) -// .unwrap(); -// assert_eq!(window_frame_groups.window_frame_start_idx, DELTA); -// assert_eq!(window_frame_groups.window_frame_end_idx, DELTA); -// assert!(window_frame_groups.reached_end); -// assert_eq!(window_frame_groups.group_start_indices.len(), 0); -// -// window_frame_groups -// .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) -// .unwrap(); -// assert_eq!(window_frame_groups.window_frame_start_idx, DELTA); -// assert_eq!(window_frame_groups.window_frame_end_idx, DELTA); -// assert!(window_frame_groups.reached_end); -// assert_eq!(window_frame_groups.group_start_indices.len(), 0); -// -// for idx in 0..test_data.num_rows { -// let start = window_frame_groups -// .calculate_index_of_group::( -// &test_data.arrays, -// idx, -// DELTA, -// test_data.num_rows, -// ) -// .unwrap(); -// assert_eq!(start, test_data.num_rows); -// let end = window_frame_groups -// .calculate_index_of_group::( -// &test_data.arrays, -// idx, -// DELTA, -// test_data.num_rows, -// ) -// .unwrap(); -// assert_eq!(end, test_data.num_rows); -// } -// } -// -// #[test] -// fn test_window_frame_groups_preceding_and_following_delta_greater_than_partition_size( -// ) { -// const START: bool = true; -// const END: bool = false; -// const FOLLOWING: bool = false; -// const PRECEDING: bool = true; -// const DELTA: u64 = 10; -// -// let test_data = test_data(); -// let mut window_frame_groups = WindowFrameStateGroups::default(); -// window_frame_groups -// .initialize::(DELTA, &test_data.arrays) -// .unwrap(); -// assert_eq!(window_frame_groups.window_frame_start_idx, 0); -// assert_eq!(window_frame_groups.window_frame_end_idx, 0); -// assert!(!window_frame_groups.reached_end); -// assert_eq!(window_frame_groups.group_start_indices.len(), 0); -// -// window_frame_groups -// .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) -// .unwrap(); -// assert_eq!(window_frame_groups.window_frame_start_idx, 0); -// assert_eq!( -// window_frame_groups.window_frame_end_idx, -// test_data.num_groups as u64 -// ); -// assert!(window_frame_groups.reached_end); -// assert_eq!( -// window_frame_groups.group_start_indices.len(), -// test_data.num_groups -// ); -// -// for idx in 0..test_data.num_rows { -// let start = window_frame_groups -// .calculate_index_of_group::( -// &test_data.arrays, -// idx, -// DELTA, -// test_data.num_rows, -// ) -// .unwrap(); -// assert_eq!(start, 0); -// let end = window_frame_groups -// .calculate_index_of_group::( -// &test_data.arrays, -// idx, -// DELTA, -// test_data.num_rows, -// ) -// .unwrap(); -// assert_eq!(end, test_data.num_rows); -// } -// } -// -// #[test] -// fn test_window_frame_groups_preceding_and_following_1() { -// const START: bool = true; -// const END: bool = false; -// const FOLLOWING: bool = false; -// const PRECEDING: bool = true; -// const DELTA: u64 = 1; -// -// let test_data = test_data(); -// let mut window_frame_groups = WindowFrameStateGroups::default(); -// window_frame_groups -// .initialize::(DELTA, &test_data.arrays) -// .unwrap(); -// assert_eq!(window_frame_groups.window_frame_start_idx, 0); -// assert_eq!(window_frame_groups.window_frame_end_idx, 0); -// assert!(!window_frame_groups.reached_end); -// assert_eq!(window_frame_groups.group_start_indices.len(), 0); -// -// window_frame_groups -// .extend_window_frame_if_necessary::(&test_data.arrays, DELTA) -// .unwrap(); -// assert_eq!(window_frame_groups.window_frame_start_idx, 0); -// assert_eq!(window_frame_groups.window_frame_end_idx, 2 * DELTA + 1); -// assert!(!window_frame_groups.reached_end); -// assert_eq!( -// window_frame_groups.group_start_indices.len(), -// 2 * DELTA as usize + 1 -// ); -// -// for idx in 0..test_data.num_rows { -// let start_idx = if idx < DELTA as usize { -// 0 -// } else { -// test_data.group_indices[idx] - DELTA as usize -// }; -// let start = window_frame_groups -// .calculate_index_of_group::( -// &test_data.arrays, -// idx, -// DELTA, -// test_data.num_rows, -// ) -// .unwrap(); -// assert_eq!(start, test_data.group_indices[start_idx]); -// let mut end_idx = if idx >= test_data.num_groups { -// test_data.num_rows -// } else { -// test_data.next_group_indices[idx] -// }; -// for _ in 0..DELTA { -// end_idx = if end_idx >= test_data.num_groups { -// test_data.num_rows -// } else { -// test_data.next_group_indices[end_idx] -// }; -// } -// let end = window_frame_groups -// .calculate_index_of_group::( -// &test_data.arrays, -// idx, -// DELTA, -// test_data.num_rows, -// ) -// .unwrap(); -// assert_eq!(end, end_idx); -// } -// } -// -// #[test] -// fn test_window_frame_groups_preceding_1_and_unbounded_following() { -// const START: bool = true; -// const PRECEDING: bool = true; -// const DELTA: u64 = 1; -// -// let test_data = test_data(); -// let mut window_frame_groups = WindowFrameStateGroups::default(); -// window_frame_groups -// .initialize::(DELTA, &test_data.arrays) -// .unwrap(); -// assert_eq!(window_frame_groups.window_frame_start_idx, 0); -// assert_eq!(window_frame_groups.window_frame_end_idx, 0); -// assert!(!window_frame_groups.reached_end); -// assert_eq!(window_frame_groups.group_start_indices.len(), 0); -// -// for idx in 0..test_data.num_rows { -// let start_idx = if idx < DELTA as usize { -// 0 -// } else { -// test_data.group_indices[idx] - DELTA as usize -// }; -// let start = window_frame_groups -// .calculate_index_of_group::( -// &test_data.arrays, -// idx, -// DELTA, -// test_data.num_rows, -// ) -// .unwrap(); -// assert_eq!(start, test_data.group_indices[start_idx]); -// } -// } -// -// #[test] -// fn test_window_frame_groups_current_and_unbounded_following() { -// const START: bool = true; -// const PRECEDING: bool = true; -// const DELTA: u64 = 0; -// -// let test_data = test_data(); -// let mut window_frame_groups = WindowFrameStateGroups::default(); -// window_frame_groups -// .initialize::(DELTA, &test_data.arrays) -// .unwrap(); -// assert_eq!(window_frame_groups.window_frame_start_idx, 0); -// assert_eq!(window_frame_groups.window_frame_end_idx, 1); -// assert!(!window_frame_groups.reached_end); -// assert_eq!(window_frame_groups.group_start_indices.len(), 1); -// -// for idx in 0..test_data.num_rows { -// let start = window_frame_groups -// .calculate_index_of_group::( -// &test_data.arrays, -// idx, -// DELTA, -// test_data.num_rows, -// ) -// .unwrap(); -// assert_eq!(start, test_data.group_indices[idx]); -// } -// } -// } From 79ee7097cad5581e948df3b4b499a881e3d6b99d Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 6 Feb 2023 14:25:22 +0300 Subject: [PATCH 09/18] simplifications --- datafusion/common/src/utils.rs | 7 +-- .../src/window/window_frame_state.rs | 63 ++++++++++--------- 2 files changed, 38 insertions(+), 32 deletions(-) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 42924348b1ba..a4eccef04c08 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -158,10 +158,9 @@ where Ok(low) } -/// This function searches for a tuple of given values (`target`) among a slice of -/// the given rows (`item_columns`) via a linear scan. The slice starts at the index -/// `low` and ends at the index `high`. The boolean-valued function `compare_fn` -/// specifies the stopping criterion. +/// This function implements a linear search algorithm to find the +/// first row in a table that differs from the starting row e.g +/// row at idx `low` pub fn search_till_change( item_columns: &[ArrayRef], mut low: usize, diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 33850742e4d3..916c7c86b7c3 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -101,7 +101,6 @@ impl<'a> WindowFrameContext<'a> { sort_options, length, idx, - last_range, ), } } @@ -347,6 +346,16 @@ impl WindowFrameStateRange { // scan groups of data while processing window frames. #[derive(Debug, Default)] pub struct WindowFrameStateGroups { + // Stores the tuple where first element is the row group contains + // second value is the index where group ends + // For instance, + // [ + // [1,1], + // [1,1], + // [2,1], + // [2,1], + // ] + // would produce VecDeque::from([([1,1], 2), ([2,1], 4)]); group_start_indices: VecDeque<(Vec, usize)>, // Keeps the groups index that row index belongs cur_group_idx: usize, @@ -360,7 +369,6 @@ impl WindowFrameStateGroups { sort_options: &[SortOptions], length: usize, idx: usize, - last_range: &Range, ) -> Result> { if range_columns.is_empty() { return Err(DataFusionError::Execution( @@ -378,7 +386,6 @@ impl WindowFrameStateGroups { sort_options, idx, Some(n), - last_range, length, )? } @@ -392,7 +399,6 @@ impl WindowFrameStateGroups { sort_options, idx, None, - last_range, length, )? } @@ -403,7 +409,6 @@ impl WindowFrameStateGroups { sort_options, idx, Some(n), - last_range, length, )?, }; @@ -414,7 +419,6 @@ impl WindowFrameStateGroups { sort_options, idx, Some(n), - last_range, length, )?, WindowFrameBound::CurrentRow => { @@ -426,7 +430,6 @@ impl WindowFrameStateGroups { sort_options, idx, None, - last_range, length, )? } @@ -441,7 +444,6 @@ impl WindowFrameStateGroups { sort_options, idx, Some(n), - last_range, length, )? } @@ -452,21 +454,16 @@ impl WindowFrameStateGroups { /// This function does the heavy lifting when finding range boundaries. It is meant to be /// called twice, in succession, to get window frame start and end indices (with `SIDE` - /// supplied as true and false, respectively). + /// supplied as true and false, respectively) `PRECEDING` determines sign of the delta ( + /// where true represents negative, false represents positive) fn calculate_index_of_row( &mut self, range_columns: &[ArrayRef], _sort_options: &[SortOptions], idx: usize, delta: Option<&ScalarValue>, - last_range: &Range, length: usize, ) -> Result { - let search_start = if SIDE { - last_range.start - } else { - last_range.end - }; let delta = if let Some(delta) = delta { match delta { ScalarValue::UInt64(Some(val)) => Ok(*val as usize), @@ -475,18 +472,19 @@ impl WindowFrameStateGroups { } else { Ok(0) }?; - let mut change_idx = search_start; + let mut group_start = 0; let last_group = self.group_start_indices.back(); - if let Some((_last_group, bound_idx)) = last_group { - // Start searching from change point from last boundary - change_idx = *bound_idx; + if let Some((_last_group, group_end)) = last_group { + // Start searching from last group boundary + group_start = *group_end; } // Progress groups until idx is inside a group - while idx > change_idx { - let group_row = get_row_at_idx(range_columns, change_idx)?; - change_idx = search_till_change(range_columns, change_idx, length)?; - self.group_start_indices.push_back((group_row, change_idx)); + while idx > group_start { + let group_row = get_row_at_idx(range_columns, group_start)?; + let group_end = search_till_change(range_columns, group_start, length)?; + self.group_start_indices.push_back((group_row, group_end)); + group_start = group_end; } // Update the group index `idx` belongs. @@ -497,6 +495,7 @@ impl WindowFrameStateGroups { break; } } + // Group idx of the frame boundary let group_idx = if PRECEDING { if self.cur_group_idx > delta { self.cur_group_idx - delta @@ -508,29 +507,37 @@ impl WindowFrameStateGroups { }; // Expand group_start_indices until it includes at least group_idx - while self.group_start_indices.len() <= group_idx && change_idx < length { - // println!("change idx: {:?}, idx: {:?}, self.cur_group_idx: {:?}", change_idx, idx, self.cur_group_idx); - let group_row = get_row_at_idx(range_columns, change_idx)?; - change_idx = search_till_change(range_columns, change_idx, length)?; - self.group_start_indices.push_back((group_row, change_idx)); + while self.group_start_indices.len() <= group_idx && group_start < length { + let group_row = get_row_at_idx(range_columns, group_start)?; + let group_end = search_till_change(range_columns, group_start, length)?; + self.group_start_indices.push_back((group_row, group_end)); + group_start = group_end; } + + // calculates index of the group boundary Ok(match (SIDE, PRECEDING) { + // window frame start (true, _) => { let group_idx = min(group_idx, self.group_start_indices.len()); if group_idx > 0 { + // window frame start is: end boundary of previous group self.group_start_indices[group_idx - 1].1 } else { + // If previous group is out of table, window frame start is 0 0 } } + // window frame end, PRECEDING n (false, true) => { if self.cur_group_idx >= delta { let group_idx = self.cur_group_idx - delta; self.group_start_indices[group_idx].1 } else { + // group is out of table hence end of window frame is 0 0 } } + // window frame end, FOLLOWING n (false, false) => { let group_idx = min( self.cur_group_idx + delta, From 795037799a12fa88552ae50e429495a3aa7297fc Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 9 Feb 2023 15:07:23 +0300 Subject: [PATCH 10/18] Add unit tests --- datafusion/common/src/utils.rs | 19 --- .../src/window/window_frame_state.rs | 123 +++++++++++++++++- 2 files changed, 119 insertions(+), 23 deletions(-) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index a4eccef04c08..3c073015343c 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -158,25 +158,6 @@ where Ok(low) } -/// This function implements a linear search algorithm to find the -/// first row in a table that differs from the starting row e.g -/// row at idx `low` -pub fn search_till_change( - item_columns: &[ArrayRef], - mut low: usize, - high: usize, -) -> Result { - let start_row = get_row_at_idx(item_columns, low)?; - while low < high { - let val = get_row_at_idx(item_columns, low)?; - if !start_row.eq(&val) { - break; - } - low += 1; - } - Ok(low) -} - #[cfg(test)] mod tests { use arrow::array::Float64Array; diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 916c7c86b7c3..744477078150 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -21,7 +21,7 @@ use arrow::array::ArrayRef; use arrow::compute::kernels::sort::SortOptions; use datafusion_common::utils::{ - compare_rows, get_row_at_idx, search_in_slice, search_till_change, + compare_rows, get_row_at_idx, linear_search, search_in_slice, }; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; @@ -459,7 +459,7 @@ impl WindowFrameStateGroups { fn calculate_index_of_row( &mut self, range_columns: &[ArrayRef], - _sort_options: &[SortOptions], + sort_options: &[SortOptions], idx: usize, delta: Option<&ScalarValue>, length: usize, @@ -482,7 +482,9 @@ impl WindowFrameStateGroups { // Progress groups until idx is inside a group while idx > group_start { let group_row = get_row_at_idx(range_columns, group_start)?; - let group_end = search_till_change(range_columns, group_start, length)?; + // find end boundary of of the group (search right boundary) + let group_end = + linear_search::(range_columns, &group_row, sort_options)?; self.group_start_indices.push_back((group_row, group_end)); group_start = group_end; } @@ -509,7 +511,9 @@ impl WindowFrameStateGroups { // Expand group_start_indices until it includes at least group_idx while self.group_start_indices.len() <= group_idx && group_start < length { let group_row = get_row_at_idx(range_columns, group_start)?; - let group_end = search_till_change(range_columns, group_start, length)?; + // find end boundary of of the group (search right boundary) + let group_end = + linear_search::(range_columns, &group_row, sort_options)?; self.group_start_indices.push_back((group_row, group_end)); group_start = group_end; } @@ -548,3 +552,114 @@ impl WindowFrameStateGroups { }) } } + +#[cfg(test)] +mod tests { + use crate::window::window_frame_state::WindowFrameStateGroups; + use arrow::array::{ArrayRef, Float64Array}; + use arrow_schema::SortOptions; + use datafusion_common::from_slice::FromSlice; + use datafusion_common::Result; + use datafusion_common::ScalarValue; + use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; + use std::ops::Range; + use std::sync::Arc; + + fn get_test_data() -> (Vec, Vec) { + let range_columns: Vec = vec![Arc::new(Float64Array::from_slice([ + 5.0, 7.0, 8.0, 8.0, 9., 10., 10., 10., 11., + ]))]; + let sort_options = vec![SortOptions { + descending: false, + nulls_first: false, + }]; + + (range_columns, sort_options) + } + + fn assert_expected( + expected_results: Vec<(Range, usize)>, + window_frame: &Arc, + ) -> Result<()> { + let mut window_frame_groups = WindowFrameStateGroups::default(); + let (range_columns, sort_options) = get_test_data(); + let n_row = range_columns[0].len(); + for (idx, (expected_range, expected_group_idx)) in + expected_results.into_iter().enumerate() + { + let range = window_frame_groups.calculate_range( + window_frame, + &range_columns, + &sort_options, + n_row, + idx, + )?; + assert_eq!(range, expected_range); + assert_eq!(window_frame_groups.cur_group_idx, expected_group_idx); + } + Ok(()) + } + + #[test] + fn test_window_frame_group_boundaries() -> Result<()> { + let window_frame = Arc::new(WindowFrame { + units: WindowFrameUnits::Groups, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))), + end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(1))), + }); + let expected_results = vec![ + (Range { start: 0, end: 2 }, 0), + (Range { start: 0, end: 4 }, 1), + (Range { start: 1, end: 5 }, 2), + (Range { start: 1, end: 5 }, 2), + (Range { start: 2, end: 8 }, 3), + (Range { start: 4, end: 9 }, 4), + (Range { start: 4, end: 9 }, 4), + (Range { start: 4, end: 9 }, 4), + (Range { start: 5, end: 9 }, 5), + ]; + assert_expected(expected_results, &window_frame) + } + + #[test] + fn test_window_frame_group_boundaries_both_following() -> Result<()> { + let window_frame = Arc::new(WindowFrame { + units: WindowFrameUnits::Groups, + start_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(1))), + end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), + }); + let expected_results = vec![ + (Range:: { start: 1, end: 4 }, 0), + (Range:: { start: 2, end: 5 }, 1), + (Range:: { start: 4, end: 8 }, 2), + (Range:: { start: 4, end: 8 }, 2), + (Range:: { start: 5, end: 9 }, 3), + (Range:: { start: 8, end: 9 }, 4), + (Range:: { start: 8, end: 9 }, 4), + (Range:: { start: 8, end: 9 }, 4), + (Range:: { start: 9, end: 9 }, 5), + ]; + assert_expected(expected_results, &window_frame) + } + + #[test] + fn test_window_frame_group_boundaries_both_preceding() -> Result<()> { + let window_frame = Arc::new(WindowFrame { + units: WindowFrameUnits::Groups, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), + end_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))), + }); + let expected_results = vec![ + (Range:: { start: 0, end: 0 }, 0), + (Range:: { start: 0, end: 1 }, 1), + (Range:: { start: 0, end: 2 }, 2), + (Range:: { start: 0, end: 2 }, 2), + (Range:: { start: 1, end: 4 }, 3), + (Range:: { start: 2, end: 5 }, 4), + (Range:: { start: 2, end: 5 }, 4), + (Range:: { start: 2, end: 5 }, 4), + (Range:: { start: 4, end: 8 }, 5), + ]; + assert_expected(expected_results, &window_frame) + } +} From 40aa24d4d18ae93960a42c8c52762722e2fc5c29 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Sun, 12 Feb 2023 19:29:50 -0600 Subject: [PATCH 11/18] Remove sort options from GROUPS calculations, various code simplifications and comment clarifications --- .../src/window/window_frame_state.rs | 168 +++++++++--------- 1 file changed, 83 insertions(+), 85 deletions(-) diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 744477078150..69eb31d90c75 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -15,14 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! This module provides utilities for window frame index calculations depending on the window frame mode: -//! RANGE, ROWS, GROUPS. +//! This module provides utilities for window frame index calculations +//! depending on the window frame mode: RANGE, ROWS, GROUPS. use arrow::array::ArrayRef; use arrow::compute::kernels::sort::SortOptions; -use datafusion_common::utils::{ - compare_rows, get_row_at_idx, linear_search, search_in_slice, -}; +use datafusion_common::utils::{compare_rows, get_row_at_idx, search_in_slice}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::cmp::min; @@ -77,8 +75,9 @@ impl<'a> WindowFrameContext<'a> { WindowFrameContext::Rows(window_frame) => { Self::calculate_range_rows(window_frame, length, idx) } - // sort_options is used in RANGE mode calculations because the ordering and the position of the nulls - // have impact on the range calculations and comparison of the rows. + // Sort options is used in RANGE mode calculations because the + // ordering or position of NULLs impact range calculations and + // comparison of rows. WindowFrameContext::Range { window_frame, ref mut state, @@ -90,18 +89,13 @@ impl<'a> WindowFrameContext<'a> { idx, last_range, ), - // sort_options is not used in GROUPS mode calculations as the inequality of two rows is the indicator - // of a group change, and the ordering and the position of the nulls do not have impact on inequality. + // Sort options is not used in GROUPS mode calculations as the + // inequality of two rows indicates a group change, and ordering + // or position of NULLs do not impact inequality. WindowFrameContext::Groups { window_frame, ref mut state, - } => state.calculate_range( - window_frame, - range_columns, - sort_options, - length, - idx, - ), + } => state.calculate_range(window_frame, range_columns, length, idx), } } @@ -281,7 +275,11 @@ impl WindowFrameStateRange { let end_range = if let Some(delta) = delta { let is_descending: bool = sort_options .first() - .ok_or_else(|| DataFusionError::Internal("Array is empty".to_string()))? + .ok_or_else(|| { + DataFusionError::Internal( + "Sort options unexpectedly absent in a window frame".to_string(), + ) + })? .descending; current_row_values @@ -291,7 +289,7 @@ impl WindowFrameStateRange { return Ok(value.clone()); } if SEARCH_SIDE == is_descending { - // TODO: Handle positive overflows + // TODO: Handle positive overflows. value.add(delta) } else if value.is_unsigned() && value < delta { // NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue. @@ -299,7 +297,7 @@ impl WindowFrameStateRange { // change the following statement to use that. value.sub(value) } else { - // TODO: Handle negative overflows + // TODO: Handle negative overflows. value.sub(delta) } }) @@ -346,19 +344,12 @@ impl WindowFrameStateRange { // scan groups of data while processing window frames. #[derive(Debug, Default)] pub struct WindowFrameStateGroups { - // Stores the tuple where first element is the row group contains - // second value is the index where group ends - // For instance, - // [ - // [1,1], - // [1,1], - // [2,1], - // [2,1], - // ] - // would produce VecDeque::from([([1,1], 2), ([2,1], 4)]); + /// A tuple containing group values and the row index where the group ends. + /// Example: [[1, 1], [1, 1], [2, 1], [2, 1], ...] would correspond to + /// [([1, 1], 2), ([2, 1], 4), ...]. group_start_indices: VecDeque<(Vec, usize)>, - // Keeps the groups index that row index belongs - cur_group_idx: usize, + /// The group index to which the row index belongs. + current_group_idx: usize, } impl WindowFrameStateGroups { @@ -366,7 +357,6 @@ impl WindowFrameStateGroups { &mut self, window_frame: &Arc, range_columns: &[ArrayRef], - sort_options: &[SortOptions], length: usize, idx: usize, ) -> Result> { @@ -383,7 +373,6 @@ impl WindowFrameStateGroups { } else { self.calculate_index_of_row::( range_columns, - sort_options, idx, Some(n), length, @@ -396,7 +385,6 @@ impl WindowFrameStateGroups { } else { self.calculate_index_of_row::( range_columns, - sort_options, idx, None, length, @@ -406,7 +394,6 @@ impl WindowFrameStateGroups { WindowFrameBound::Following(ref n) => self .calculate_index_of_row::( range_columns, - sort_options, idx, Some(n), length, @@ -416,7 +403,6 @@ impl WindowFrameStateGroups { WindowFrameBound::Preceding(ref n) => self .calculate_index_of_row::( range_columns, - sort_options, idx, Some(n), length, @@ -427,7 +413,6 @@ impl WindowFrameStateGroups { } else { self.calculate_index_of_row::( range_columns, - sort_options, idx, None, length, @@ -441,7 +426,6 @@ impl WindowFrameStateGroups { } else { self.calculate_index_of_row::( range_columns, - sort_options, idx, Some(n), length, @@ -454,97 +438,109 @@ impl WindowFrameStateGroups { /// This function does the heavy lifting when finding range boundaries. It is meant to be /// called twice, in succession, to get window frame start and end indices (with `SIDE` - /// supplied as true and false, respectively) `PRECEDING` determines sign of the delta ( - /// where true represents negative, false represents positive) - fn calculate_index_of_row( + /// supplied as true and false, respectively). Generic argument `SEARCH_SIDE` determines + /// the sign of `delta` (where true/false represents negative/positive respectively). + fn calculate_index_of_row( &mut self, range_columns: &[ArrayRef], - sort_options: &[SortOptions], idx: usize, delta: Option<&ScalarValue>, length: usize, ) -> Result { let delta = if let Some(delta) = delta { - match delta { - ScalarValue::UInt64(Some(val)) => Ok(*val as usize), - _ => Err(DataFusionError::Execution("expects uint64".to_string())), + if let ScalarValue::UInt64(Some(value)) = delta { + *value as usize + } else { + return Err(DataFusionError::Internal( + "Unexpectedly got a non-UInt64 value in a GROUPS mode window frame" + .to_string(), + )); } } else { - Ok(0) - }?; + 0 + }; let mut group_start = 0; let last_group = self.group_start_indices.back(); - if let Some((_last_group, group_end)) = last_group { - // Start searching from last group boundary + if let Some((_, group_end)) = last_group { + // Start searching from the last group boundary: group_start = *group_end; } - // Progress groups until idx is inside a group + // Advance groups until `idx` is inside a group: while idx > group_start { let group_row = get_row_at_idx(range_columns, group_start)?; - // find end boundary of of the group (search right boundary) - let group_end = - linear_search::(range_columns, &group_row, sort_options)?; + // Find end boundary of the group (search right boundary): + let group_end = search_in_slice( + range_columns, + &group_row, + check_equality, + group_start, + length, + )?; self.group_start_indices.push_back((group_row, group_end)); group_start = group_end; } - // Update the group index `idx` belongs. - while self.cur_group_idx < self.group_start_indices.len() { - if idx >= self.group_start_indices[self.cur_group_idx].1 { - self.cur_group_idx += 1; - } else { - break; - } + // Update the group index `idx` belongs to: + while self.current_group_idx < self.group_start_indices.len() + && idx >= self.group_start_indices[self.current_group_idx].1 + { + self.current_group_idx += 1; } - // Group idx of the frame boundary - let group_idx = if PRECEDING { - if self.cur_group_idx > delta { - self.cur_group_idx - delta + + // Find the group index of the frame boundary: + let group_idx = if SEARCH_SIDE { + if self.current_group_idx > delta { + self.current_group_idx - delta } else { 0 } } else { - self.cur_group_idx + delta + self.current_group_idx + delta }; - // Expand group_start_indices until it includes at least group_idx + // Extend `group_start_indices` until it includes at least `group_idx`: while self.group_start_indices.len() <= group_idx && group_start < length { let group_row = get_row_at_idx(range_columns, group_start)?; - // find end boundary of of the group (search right boundary) - let group_end = - linear_search::(range_columns, &group_row, sort_options)?; + // Find end boundary of the group (search right boundary): + let group_end = search_in_slice( + range_columns, + &group_row, + check_equality, + group_start, + length, + )?; self.group_start_indices.push_back((group_row, group_end)); group_start = group_end; } - // calculates index of the group boundary - Ok(match (SIDE, PRECEDING) { - // window frame start + // Calculate index of the group boundary: + Ok(match (SIDE, SEARCH_SIDE) { + // Window frame start: (true, _) => { let group_idx = min(group_idx, self.group_start_indices.len()); if group_idx > 0 { - // window frame start is: end boundary of previous group + // Normally, start at the boundary of the previous group. self.group_start_indices[group_idx - 1].1 } else { - // If previous group is out of table, window frame start is 0 + // If previous group is out of the table, start at zero. 0 } } - // window frame end, PRECEDING n + // Window frame end, PRECEDING n (false, true) => { - if self.cur_group_idx >= delta { - let group_idx = self.cur_group_idx - delta; + if self.current_group_idx >= delta { + let group_idx = self.current_group_idx - delta; self.group_start_indices[group_idx].1 } else { - // group is out of table hence end of window frame is 0 + // Group is out of the table, therefore end at zero. 0 } } - // window frame end, FOLLOWING n + // Window frame end, FOLLOWING n (false, false) => { let group_idx = min( - self.cur_group_idx + delta, + self.current_group_idx + delta, self.group_start_indices.len() - 1, ); self.group_start_indices[group_idx].1 @@ -553,14 +549,17 @@ impl WindowFrameStateGroups { } } +fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result { + Ok(current == target) +} + #[cfg(test)] mod tests { use crate::window::window_frame_state::WindowFrameStateGroups; use arrow::array::{ArrayRef, Float64Array}; use arrow_schema::SortOptions; use datafusion_common::from_slice::FromSlice; - use datafusion_common::Result; - use datafusion_common::ScalarValue; + use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::ops::Range; use std::sync::Arc; @@ -582,7 +581,7 @@ mod tests { window_frame: &Arc, ) -> Result<()> { let mut window_frame_groups = WindowFrameStateGroups::default(); - let (range_columns, sort_options) = get_test_data(); + let (range_columns, _) = get_test_data(); let n_row = range_columns[0].len(); for (idx, (expected_range, expected_group_idx)) in expected_results.into_iter().enumerate() @@ -590,12 +589,11 @@ mod tests { let range = window_frame_groups.calculate_range( window_frame, &range_columns, - &sort_options, n_row, idx, )?; assert_eq!(range, expected_range); - assert_eq!(window_frame_groups.cur_group_idx, expected_group_idx); + assert_eq!(window_frame_groups.current_group_idx, expected_group_idx); } Ok(()) } From 796f5d044dbf7ff45fb13b6208e8d89153ab125c Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Sun, 12 Feb 2023 21:15:03 -0600 Subject: [PATCH 12/18] New TODOs to fix --- .../src/window/window_frame_state.rs | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 69eb31d90c75..33a55e8fce51 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -34,11 +34,13 @@ use std::sync::Arc; pub enum WindowFrameContext<'a> { // ROWS-frames are inherently stateless: Rows(&'a Arc), + // TODO: Fix this comment, briefly mention what the state is. // RANGE-frames will soon have a stateful implementation that is more efficient than a stateless one: Range { window_frame: &'a Arc, state: WindowFrameStateRange, }, + // TODO: Fix this comment, briefly mention what the state is. // GROUPS-frames have a stateful implementation that is more efficient than a stateless one: Groups { window_frame: &'a Arc, @@ -159,6 +161,8 @@ impl<'a> WindowFrameContext<'a> { } } +// TODO: Fix this struct and the comment when we move "where-we-left-off" +// information from arguments/return values into the state. /// This structure encapsulates all the state information we require as we /// scan ranges of data while processing window frames. Currently we calculate /// things from scratch every time, but we will make this incremental in the future. @@ -193,6 +197,9 @@ impl WindowFrameStateRange { } } WindowFrameBound::CurrentRow => { + // TODO: Is this check at the right place? Is this being empty only a problem in this + // match arm? Or is it even a problem here? Let's add a test to exercise this + // if it doesn't exist, and if necessary, move this check to its right place. if range_columns.is_empty() { 0 } else { @@ -227,6 +234,9 @@ impl WindowFrameStateRange { length, )?, WindowFrameBound::CurrentRow => { + // TODO: Is this check at the right place? Is this being empty only a problem in this + // match arm? Or is it even a problem here? Let's add a test to exercise this + // if it doesn't exist, and if necessary, move this check to its right place. if range_columns.is_empty() { length } else { @@ -360,6 +370,8 @@ impl WindowFrameStateGroups { length: usize, idx: usize, ) -> Result> { + // TODO: This check contradicts with the same check in the handling of CurrentRow + // below. See my comment there, and fix this in the context of that. if range_columns.is_empty() { return Err(DataFusionError::Execution( "GROUPS mode requires an ORDER BY clause".to_string(), @@ -380,6 +392,9 @@ impl WindowFrameStateGroups { } } WindowFrameBound::CurrentRow => { + // TODO: Is this check at the right place? Is this being empty only a problem in this + // match arm? Or is it even a problem here? Let's add a test to exercise this + // if it doesn't exist, and if necessary, move this check to its right place. if range_columns.is_empty() { 0 } else { @@ -408,6 +423,9 @@ impl WindowFrameStateGroups { length, )?, WindowFrameBound::CurrentRow => { + // TODO: Is this check at the right place? Is this being empty only a problem in this + // match arm? Or is it even a problem here? Let's add a test to exercise this + // if it doesn't exist, and if necessary, move this check to its right place. if range_columns.is_empty() { length } else { @@ -516,6 +534,9 @@ impl WindowFrameStateGroups { // Calculate index of the group boundary: Ok(match (SIDE, SEARCH_SIDE) { + // TODO: Is it normal that window frame start and end are asymmetric? You have + // PRECEDING and FOLLOWING cases separate for end, but not for start. + // Seems like this is OK, but let's make sure. // Window frame start: (true, _) => { let group_idx = min(group_idx, self.group_start_indices.len()); From c7817550bd5ec53e4bffbe6bd251484ad507c435 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 13 Feb 2023 10:56:32 +0300 Subject: [PATCH 13/18] Address reviews --- datafusion/core/tests/sql/window.rs | 13 ++ .../physical-expr/src/window/built_in.rs | 32 ++-- .../physical-expr/src/window/window_expr.rs | 27 +--- .../src/window/window_frame_state.rs | 143 ++++++++---------- 4 files changed, 97 insertions(+), 118 deletions(-) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 21a6062b8cd7..a92a813089c3 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1477,6 +1477,19 @@ async fn window_frame_creation() -> Result<()> { "External error: Internal error: Operator - is not implemented for types UInt32(1) and Utf8(\"1 DAY\"). This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker" ); + let df = ctx + .sql( + "SELECT + COUNT(c1) OVER(groups BETWEEN current row and UNBOUNDED FOLLOWING) + FROM aggregate_test_100;", + ) + .await?; + let results = df.collect().await; + assert_contains!( + results.err().unwrap().to_string(), + "External error: Execution error: GROUPS mode requires an ORDER BY clause" + ); + Ok(()) } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index b53164f66944..70ddb2c7671a 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -104,17 +104,15 @@ impl WindowExpr for BuiltInWindowExpr { let mut row_wise_results = vec![]; let (values, order_bys) = self.get_values_orderbys(batch)?; - let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); - let range = Range { start: 0, end: 0 }; + let mut window_frame_ctx = WindowFrameContext::new( + &self.window_frame, + sort_options, + Range { start: 0, end: 0 }, + ); // We iterate on each row to calculate window frame range and and window function result for idx in 0..num_rows { - let range = window_frame_ctx.calculate_range( - &order_bys, - &sort_options, - num_rows, - idx, - &range, - )?; + let range = + window_frame_ctx.calculate_range(&order_bys, num_rows, idx)?; let value = evaluator.evaluate_inside_range(&values, &range)?; row_wise_results.push(value); } @@ -168,7 +166,13 @@ impl WindowExpr for BuiltInWindowExpr { // We iterate on each row to perform a running calculation. let record_batch = &partition_batch_state.record_batch; let num_rows = record_batch.num_rows(); - let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + let last_range = state.window_frame_range.clone(); + let mut window_frame_ctx = WindowFrameContext::new( + &self.window_frame, + sort_options.clone(), + // Start search from the last range + last_range, + ); let sort_partition_points = if evaluator.include_rank() { let columns = self.sort_columns(record_batch)?; self.evaluate_partition_points(num_rows, &columns)? @@ -179,13 +183,7 @@ impl WindowExpr for BuiltInWindowExpr { let mut last_range = state.window_frame_range.clone(); for idx in state.last_calculated_index..num_rows { state.window_frame_range = if self.expr.uses_window_frame() { - window_frame_ctx.calculate_range( - &order_bys, - &sort_options, - num_rows, - idx, - &state.window_frame_range, - ) + window_frame_ctx.calculate_range(&order_bys, num_rows, idx) } else { evaluator.get_range(state, num_rows) }?; diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index a38d5de542c0..ced1f0b1aa69 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -162,18 +162,10 @@ pub trait AggregateWindowExpr: WindowExpr { /// Evaluates the window function against the batch. fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result { - let mut window_frame_ctx = WindowFrameContext::new(self.get_window_frame()); let mut accumulator = self.get_accumulator()?; let mut last_range = Range { start: 0, end: 0 }; let mut idx = 0; - self.get_result_column( - &mut accumulator, - batch, - &mut window_frame_ctx, - &mut last_range, - &mut idx, - false, - ) + self.get_result_column(&mut accumulator, batch, &mut last_range, &mut idx, false) } /// Statefully evaluates the window function against the batch. Maintains @@ -207,11 +199,9 @@ pub trait AggregateWindowExpr: WindowExpr { let mut state = &mut window_state.state; let record_batch = &partition_batch_state.record_batch; - let mut window_frame_ctx = WindowFrameContext::new(self.get_window_frame()); let out_col = self.get_result_column( accumulator, record_batch, - &mut window_frame_ctx, &mut state.window_frame_range, &mut state.last_calculated_index, !partition_batch_state.is_end, @@ -230,7 +220,6 @@ pub trait AggregateWindowExpr: WindowExpr { &self, accumulator: &mut Box, record_batch: &RecordBatch, - window_frame_ctx: &mut WindowFrameContext, last_range: &mut Range, idx: &mut usize, not_end: bool, @@ -240,15 +229,15 @@ pub trait AggregateWindowExpr: WindowExpr { let length = values[0].len(); let sort_options: Vec = self.order_by().iter().map(|o| o.options).collect(); + let mut window_frame_ctx = WindowFrameContext::new( + self.get_window_frame(), + sort_options, + // Start search from the last range + last_range.clone(), + ); let mut row_wise_results: Vec = vec![]; while *idx < length { - let cur_range = window_frame_ctx.calculate_range( - &order_bys, - &sort_options, - length, - *idx, - last_range, - )?; + let cur_range = window_frame_ctx.calculate_range(&order_bys, length, *idx)?; // Exit if the range extends all the way: if cur_range.end == length && not_end { break; diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 33a55e8fce51..5276b027dd77 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -34,14 +34,14 @@ use std::sync::Arc; pub enum WindowFrameContext<'a> { // ROWS-frames are inherently stateless: Rows(&'a Arc), - // TODO: Fix this comment, briefly mention what the state is. - // RANGE-frames will soon have a stateful implementation that is more efficient than a stateless one: + // RANGE-frames store window frame to calculate window frame boundaries + // In `state`, it keeps track of `last_range` calculated to increase search speed. Range { window_frame: &'a Arc, state: WindowFrameStateRange, }, - // TODO: Fix this comment, briefly mention what the state is. - // GROUPS-frames have a stateful implementation that is more efficient than a stateless one: + // GROUPS-frames store window frame to calculate window frame boundaries + // In `state`, we store the boundaries of each group from the start of the table. Groups { window_frame: &'a Arc, state: WindowFrameStateGroups, @@ -50,12 +50,16 @@ pub enum WindowFrameContext<'a> { impl<'a> WindowFrameContext<'a> { /// Create a new default state for the given window frame. - pub fn new(window_frame: &'a Arc) -> Self { + pub fn new( + window_frame: &'a Arc, + sort_options: Vec, + last_range: Range, + ) -> Self { match window_frame.units { WindowFrameUnits::Rows => WindowFrameContext::Rows(window_frame), WindowFrameUnits::Range => WindowFrameContext::Range { window_frame, - state: WindowFrameStateRange::default(), + state: WindowFrameStateRange::new(sort_options, last_range), }, WindowFrameUnits::Groups => WindowFrameContext::Groups { window_frame, @@ -68,10 +72,8 @@ impl<'a> WindowFrameContext<'a> { pub fn calculate_range( &mut self, range_columns: &[ArrayRef], - sort_options: &[SortOptions], length: usize, idx: usize, - last_range: &Range, ) -> Result> { match *self { WindowFrameContext::Rows(window_frame) => { @@ -83,14 +85,7 @@ impl<'a> WindowFrameContext<'a> { WindowFrameContext::Range { window_frame, ref mut state, - } => state.calculate_range( - window_frame, - range_columns, - sort_options, - length, - idx, - last_range, - ), + } => state.calculate_range(window_frame, range_columns, length, idx), // Sort options is not used in GROUPS mode calculations as the // inequality of two rows indicates a group change, and ordering // or position of NULLs do not impact inequality. @@ -161,24 +156,35 @@ impl<'a> WindowFrameContext<'a> { } } -// TODO: Fix this struct and the comment when we move "where-we-left-off" -// information from arguments/return values into the state. /// This structure encapsulates all the state information we require as we -/// scan ranges of data while processing window frames. Currently we calculate -/// things from scratch every time, but we will make this incremental in the future. +/// scan ranges of data while processing window frames. +/// `last_range` keeps the range calculated at the last search. Since we know that range only can progress forward +/// at the next search we start from `last_range`. This makes linear search amortized constant. +/// `sort_options` keeps the ordering of the columns in the ORDER BY clause. This information is used to calculate +/// range boundary, #[derive(Debug, Default)] -pub struct WindowFrameStateRange {} +pub struct WindowFrameStateRange { + last_range: Range, + sort_options: Vec, +} impl WindowFrameStateRange { + /// Creates new struct for range calculation + fn new(sort_options: Vec, last_range: Range) -> Self { + Self { + // Keeps the search range we last calculated + last_range, + sort_options, + } + } + /// This function calculates beginning/ending indices for the frame of the current row. fn calculate_range( &mut self, window_frame: &Arc, range_columns: &[ArrayRef], - sort_options: &[SortOptions], length: usize, idx: usize, - last_range: &Range, ) -> Result> { let start = match window_frame.start_bound { WindowFrameBound::Preceding(ref n) => { @@ -188,27 +194,25 @@ impl WindowFrameStateRange { } else { self.calculate_index_of_row::( range_columns, - sort_options, idx, Some(n), - last_range, length, )? } } WindowFrameBound::CurrentRow => { - // TODO: Is this check at the right place? Is this being empty only a problem in this - // match arm? Or is it even a problem here? Let's add a test to exercise this - // if it doesn't exist, and if necessary, move this check to its right place. + // If RANGE queries contain `CURRENT ROW` clause in the window frame start, + // when there is no ordering (e.g `range_columns` is empty) + // Window frame start is treated as UNBOUNDED PRECEDING. As an example + // OVER(RANGE BETWEEN CURRENT ROW and UNBOUNDED FOLLOWING) is treated as + // OVER(RANGE BETWEEN UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) if range_columns.is_empty() { 0 } else { self.calculate_index_of_row::( range_columns, - sort_options, idx, None, - last_range, length, )? } @@ -216,10 +220,8 @@ impl WindowFrameStateRange { WindowFrameBound::Following(ref n) => self .calculate_index_of_row::( range_columns, - sort_options, idx, Some(n), - last_range, length, )?, }; @@ -227,25 +229,23 @@ impl WindowFrameStateRange { WindowFrameBound::Preceding(ref n) => self .calculate_index_of_row::( range_columns, - sort_options, idx, Some(n), - last_range, length, )?, WindowFrameBound::CurrentRow => { - // TODO: Is this check at the right place? Is this being empty only a problem in this - // match arm? Or is it even a problem here? Let's add a test to exercise this - // if it doesn't exist, and if necessary, move this check to its right place. + // If RANGE queries contain `CURRENT ROW` clause in the window frame end, + // when there is no ordering (e.g `range_columns` is empty) + // Window frame end is treated as UNBOUNDED FOLLOWING. As an example + // OVER(RANGE BETWEEN UNBOUNDED PRECEDING and CURRENT ROW) is treated as + // OVER(RANGE BETWEEN UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) if range_columns.is_empty() { length } else { self.calculate_index_of_row::( range_columns, - sort_options, idx, None, - last_range, length, )? } @@ -257,15 +257,16 @@ impl WindowFrameStateRange { } else { self.calculate_index_of_row::( range_columns, - sort_options, idx, Some(n), - last_range, length, )? } } }; + // Store last calculated range, to start where we left of in the next iteration + self.last_range.start = start; + self.last_range.end = end; Ok(Range { start, end }) } @@ -275,15 +276,14 @@ impl WindowFrameStateRange { fn calculate_index_of_row( &mut self, range_columns: &[ArrayRef], - sort_options: &[SortOptions], idx: usize, delta: Option<&ScalarValue>, - last_range: &Range, length: usize, ) -> Result { let current_row_values = get_row_at_idx(range_columns, idx)?; let end_range = if let Some(delta) = delta { - let is_descending: bool = sort_options + let is_descending: bool = self + .sort_options .first() .ok_or_else(|| { DataFusionError::Internal( @@ -316,12 +316,12 @@ impl WindowFrameStateRange { current_row_values }; let search_start = if SIDE { - last_range.start + self.last_range.start } else { - last_range.end + self.last_range.end }; let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { - let cmp = compare_rows(current, target, sort_options)?; + let cmp = compare_rows(current, target, &self.sort_options)?; Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() }) }; search_in_slice(range_columns, &end_range, compare_fn, search_start, length) @@ -370,8 +370,8 @@ impl WindowFrameStateGroups { length: usize, idx: usize, ) -> Result> { - // TODO: This check contradicts with the same check in the handling of CurrentRow - // below. See my comment there, and fix this in the context of that. + // Groups mode should have an ordering + // e.g `range_columns` shouldn't be empty if range_columns.is_empty() { return Err(DataFusionError::Execution( "GROUPS mode requires an ORDER BY clause".to_string(), @@ -391,21 +391,12 @@ impl WindowFrameStateGroups { )? } } - WindowFrameBound::CurrentRow => { - // TODO: Is this check at the right place? Is this being empty only a problem in this - // match arm? Or is it even a problem here? Let's add a test to exercise this - // if it doesn't exist, and if necessary, move this check to its right place. - if range_columns.is_empty() { - 0 - } else { - self.calculate_index_of_row::( - range_columns, - idx, - None, - length, - )? - } - } + WindowFrameBound::CurrentRow => self.calculate_index_of_row::( + range_columns, + idx, + None, + length, + )?, WindowFrameBound::Following(ref n) => self .calculate_index_of_row::( range_columns, @@ -422,21 +413,12 @@ impl WindowFrameStateGroups { Some(n), length, )?, - WindowFrameBound::CurrentRow => { - // TODO: Is this check at the right place? Is this being empty only a problem in this - // match arm? Or is it even a problem here? Let's add a test to exercise this - // if it doesn't exist, and if necessary, move this check to its right place. - if range_columns.is_empty() { - length - } else { - self.calculate_index_of_row::( - range_columns, - idx, - None, - length, - )? - } - } + WindowFrameBound::CurrentRow => self.calculate_index_of_row::( + range_columns, + idx, + None, + length, + )?, WindowFrameBound::Following(ref n) => { if n.is_null() { // UNBOUNDED FOLLOWING @@ -534,9 +516,6 @@ impl WindowFrameStateGroups { // Calculate index of the group boundary: Ok(match (SIDE, SEARCH_SIDE) { - // TODO: Is it normal that window frame start and end are asymmetric? You have - // PRECEDING and FOLLOWING cases separate for end, but not for start. - // Seems like this is OK, but let's make sure. // Window frame start: (true, _) => { let group_idx = min(group_idx, self.group_start_indices.len()); From a8be4969df701fd61949157497fe0b3f302a3404 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 13 Feb 2023 14:09:48 +0300 Subject: [PATCH 14/18] Fix error --- datafusion/core/tests/sql/window.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 6ab6d8d38dcc..b275a0784825 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1487,7 +1487,7 @@ async fn window_frame_creation() -> Result<()> { let results = df.collect().await; assert_contains!( results.err().unwrap().to_string(), - "External error: Execution error: GROUPS mode requires an ORDER BY clause" + "Execution error: GROUPS mode requires an ORDER BY clause" ); Ok(()) From c53ac07ef0a14452388fe0402f0ac72ce23d17d6 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 13 Feb 2023 21:23:54 +0300 Subject: [PATCH 15/18] Prehandle range current row and unbounded following case --- datafusion/core/tests/sql/window.rs | 23 +++++----- .../src/window/window_frame_state.rs | 46 +++++-------------- datafusion/sql/src/expr/function.rs | 19 ++++++-- 3 files changed, 39 insertions(+), 49 deletions(-) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index b275a0784825..bf8e6f9583a1 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1035,21 +1035,22 @@ async fn window_frame_ranges_unbounded_preceding_following() -> Result<()> { register_aggregate_csv(&ctx).await?; let sql = "SELECT \ SUM(c2) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as sum1, \ - COUNT(*) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as cnt1 \ + COUNT(*) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as cnt1, \ + COUNT(c1) OVER(RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) as cnt2 FROM aggregate_test_100 \ ORDER BY c9 \ LIMIT 5"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------+------+", - "| sum1 | cnt1 |", - "+------+------+", - "| 285 | 100 |", - "| 123 | 63 |", - "| 285 | 100 |", - "| 123 | 63 |", - "| 123 | 63 |", - "+------+------+", + "+------+------+------+", + "| sum1 | cnt1 | cnt2 |", + "+------+------+------+", + "| 285 | 100 | 100 |", + "| 123 | 63 | 100 |", + "| 285 | 100 | 100 |", + "| 123 | 63 | 100 |", + "| 123 | 63 | 100 |", + "+------+------+------+", ]; assert_batches_eq!(expected, &actual); Ok(()) @@ -1480,7 +1481,7 @@ async fn window_frame_creation() -> Result<()> { let df = ctx .sql( "SELECT - COUNT(c1) OVER(groups BETWEEN current row and UNBOUNDED FOLLOWING) + COUNT(c1) OVER(GROUPS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM aggregate_test_100;", ) .await?; diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 5276b027dd77..b282a8103267 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -200,23 +200,12 @@ impl WindowFrameStateRange { )? } } - WindowFrameBound::CurrentRow => { - // If RANGE queries contain `CURRENT ROW` clause in the window frame start, - // when there is no ordering (e.g `range_columns` is empty) - // Window frame start is treated as UNBOUNDED PRECEDING. As an example - // OVER(RANGE BETWEEN CURRENT ROW and UNBOUNDED FOLLOWING) is treated as - // OVER(RANGE BETWEEN UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) - if range_columns.is_empty() { - 0 - } else { - self.calculate_index_of_row::( - range_columns, - idx, - None, - length, - )? - } - } + WindowFrameBound::CurrentRow => self.calculate_index_of_row::( + range_columns, + idx, + None, + length, + )?, WindowFrameBound::Following(ref n) => self .calculate_index_of_row::( range_columns, @@ -233,23 +222,12 @@ impl WindowFrameStateRange { Some(n), length, )?, - WindowFrameBound::CurrentRow => { - // If RANGE queries contain `CURRENT ROW` clause in the window frame end, - // when there is no ordering (e.g `range_columns` is empty) - // Window frame end is treated as UNBOUNDED FOLLOWING. As an example - // OVER(RANGE BETWEEN UNBOUNDED PRECEDING and CURRENT ROW) is treated as - // OVER(RANGE BETWEEN UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) - if range_columns.is_empty() { - length - } else { - self.calculate_index_of_row::( - range_columns, - idx, - None, - length, - )? - } - } + WindowFrameBound::CurrentRow => self.calculate_index_of_row::( + range_columns, + idx, + None, + length, + )?, WindowFrameBound::Following(ref n) => { if n.is_null() { // UNBOUNDED FOLLOWING diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 1845d59472c6..70015bbaea51 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -17,11 +17,11 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::utils::normalize_ident; -use datafusion_common::{DFSchema, DataFusionError, Result}; +use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{ expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, - WindowFrameUnits, WindowFunction, + WindowFrameBound, WindowFrameUnits, WindowFunction, }; use sqlparser::ast::{ Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, @@ -69,8 +69,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if WindowFrameUnits::Range == window_frame.units && order_by.len() != 1 { - Err(DataFusionError::Plan(format!( - "With window frame of type RANGE, the order by expression must be of length 1, got {}", order_by.len()))) + // Construct equivalent window frame. Only in below expressions for RANGE ORDER BY is not a requirement + // In the downstream we will assume that RANGE contains ORDER BY clause. + if (window_frame.start_bound.is_unbounded() || window_frame.start_bound == WindowFrameBound::CurrentRow) + && (window_frame.end_bound == WindowFrameBound::CurrentRow || window_frame.end_bound.is_unbounded()){ + Ok(WindowFrame{ + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::Null), + end_bound: WindowFrameBound::Following(ScalarValue::Null) + }) + } else { + Err(DataFusionError::Plan(format!( + "With window frame of type RANGE, the order by expression must be of length 1, got {}", order_by.len()))) + } } else { Ok(window_frame) } From a4f201d6cfddeadccc7ce4ecae07f3164f62b595 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 14 Feb 2023 10:22:07 +0300 Subject: [PATCH 16/18] Fix error --- datafusion/core/tests/sql/window.rs | 67 +++++++++++++++--------- datafusion/sql/src/expr/function.rs | 14 ++--- datafusion/sql/tests/integration_test.rs | 21 -------- 3 files changed, 50 insertions(+), 52 deletions(-) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 3129c1492e2c..b3069f37a683 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -621,22 +621,21 @@ async fn window_frame_ranges_unbounded_preceding_following() -> Result<()> { register_aggregate_csv(&ctx).await?; let sql = "SELECT \ SUM(c2) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as sum1, \ - COUNT(*) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as cnt1, \ - COUNT(c1) OVER(RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) as cnt2 + COUNT(*) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as cnt1 \ FROM aggregate_test_100 \ ORDER BY c9 \ LIMIT 5"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------+------+------+", - "| sum1 | cnt1 | cnt2 |", - "+------+------+------+", - "| 285 | 100 | 100 |", - "| 123 | 63 | 100 |", - "| 285 | 100 | 100 |", - "| 123 | 63 | 100 |", - "| 123 | 63 | 100 |", - "+------+------+------+", + "+------+------+", + "| sum1 | cnt1 |", + "+------+------+", + "| 285 | 100 |", + "| 123 | 63 |", + "| 285 | 100 |", + "| 123 | 63 |", + "| 123 | 63 |", + "+------+------+", ]; assert_batches_eq!(expected, &actual); Ok(()) @@ -1051,19 +1050,6 @@ async fn window_frame_creation() -> Result<()> { "Execution error: Invalid window frame: start bound (2 FOLLOWING) cannot be larger than end bound (1 FOLLOWING)" ); - let df = ctx - .sql( - "SELECT - COUNT(c1) OVER (ORDER BY c2 RANGE BETWEEN '1 DAY' PRECEDING AND '2 DAY' FOLLOWING) - FROM aggregate_test_100;", - ) - .await?; - let results = df.collect().await; - assert_contains!( - results.err().unwrap().to_string(), - "Internal error: Operator - is not implemented for types UInt32(1) and Utf8(\"1 DAY\"). This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker" - ); - let df = ctx .sql( "SELECT @@ -1150,6 +1136,39 @@ async fn test_window_row_number_aggregate() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_window_range_equivalent_frames() -> Result<()> { + let config = SessionConfig::new(); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + COUNT(*) OVER(ORDER BY c9, c1 RANGE BETWEEN CURRENT ROW AND CURRENT ROW) AS cnt1, + COUNT(*) OVER(ORDER BY c9, c1 RANGE UNBOUNDED PRECEDING) AS cnt2, + COUNT(*) OVER(ORDER BY c9, c1 RANGE CURRENT ROW) AS cnt3, + COUNT(*) OVER(RANGE BETWEEN CURRENT ROW AND CURRENT ROW) AS cnt4, + COUNT(*) OVER(RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cnt5, + COUNT(*) OVER(RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS cnt6 + FROM aggregate_test_100 + ORDER BY c9 + LIMIT 5"; + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------+------+------+------+------+------+------+", + "| c9 | cnt1 | cnt2 | cnt3 | cnt4 | cnt5 | cnt6 |", + "+-----------+------+------+------+------+------+------+", + "| 28774375 | 1 | 1 | 1 | 100 | 100 | 100 |", + "| 63044568 | 1 | 2 | 1 | 100 | 100 | 100 |", + "| 141047417 | 1 | 3 | 1 | 100 | 100 | 100 |", + "| 141680161 | 1 | 4 | 1 | 100 | 100 | 100 |", + "| 145294611 | 1 | 5 | 1 | 100 | 100 | 100 |", + "+-----------+------+------+------+------+------+------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn test_window_cume_dist() -> Result<()> { let config = SessionConfig::new(); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 70015bbaea51..f1343062a047 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -17,7 +17,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::utils::normalize_ident; -use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; +use datafusion_common::{DFSchema, DataFusionError, Result}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{ expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, @@ -72,12 +72,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Construct equivalent window frame. Only in below expressions for RANGE ORDER BY is not a requirement // In the downstream we will assume that RANGE contains ORDER BY clause. if (window_frame.start_bound.is_unbounded() || window_frame.start_bound == WindowFrameBound::CurrentRow) - && (window_frame.end_bound == WindowFrameBound::CurrentRow || window_frame.end_bound.is_unbounded()){ - Ok(WindowFrame{ - units: WindowFrameUnits::Rows, - start_bound: WindowFrameBound::Preceding(ScalarValue::Null), - end_bound: WindowFrameBound::Following(ScalarValue::Null) - }) + && (window_frame.end_bound == WindowFrameBound::CurrentRow || window_frame.end_bound.is_unbounded()) { + if order_by.is_empty() { + Ok(WindowFrame::new(false)) + } else { + Ok(window_frame) + } } else { Err(DataFusionError::Plan(format!( "With window frame of type RANGE, the order by expression must be of length 1, got {}", order_by.len()))) diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs index c75f93d36932..44c0559ef35a 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/integration_test.rs @@ -2101,27 +2101,6 @@ fn over_order_by_with_window_frame_single_end() { quick_test(sql, expected); } -#[test] -fn over_order_by_with_window_frame_range_order_by_check() { - let sql = "SELECT order_id, MAX(qty) OVER (RANGE UNBOUNDED PRECEDING) from orders"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "Plan(\"With window frame of type RANGE, the order by expression must be of length 1, got 0\")", - format!("{err:?}") - ); -} - -#[test] -fn over_order_by_with_window_frame_range_order_by_check_2() { - let sql = - "SELECT order_id, MAX(qty) OVER (ORDER BY order_id, qty RANGE UNBOUNDED PRECEDING) from orders"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "Plan(\"With window frame of type RANGE, the order by expression must be of length 1, got 2\")", - format!("{err:?}") - ); -} - #[test] fn over_order_by_with_window_frame_single_end_groups() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; From d60c8965e844ff1a8776c3bdcc3fe48125fb8534 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Tue, 14 Feb 2023 15:26:22 -0600 Subject: [PATCH 17/18] Move a check from execution to planning, reduce code duplication --- datafusion/core/src/physical_plan/planner.rs | 2 +- datafusion/core/tests/sql/window.rs | 25 +++++------ datafusion/expr/src/window_frame.rs | 30 ++++++++++++++ .../src/window/window_frame_state.rs | 41 +++++++++---------- .../proto/src/logical_plan/from_proto.rs | 22 +++++----- datafusion/sql/src/expr/function.rs | 25 ++--------- 6 files changed, 77 insertions(+), 68 deletions(-) diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index b6269a560386..d0ee38ac90fd 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -1594,7 +1594,7 @@ pub fn create_window_expr_with_name( }) .collect::>>()?; if !is_window_valid(window_frame) { - return Err(DataFusionError::Execution(format!( + return Err(DataFusionError::Plan(format!( "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", window_frame.start_bound, window_frame.end_bound ))); diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index b3069f37a683..5476652ade07 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -982,19 +982,20 @@ async fn window_frame_groups_multiple_order_columns() -> Result<()> { async fn window_frame_groups_without_order_by() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; - // execute the query - let df = ctx + // Try executing an erroneous query (the ORDER BY clause is missing in the + // window frame): + let err = ctx .sql( "SELECT SUM(c4) OVER(PARTITION BY c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM aggregate_test_100 ORDER BY c9;", ) - .await?; - let err = df.collect().await.unwrap_err(); + .await + .unwrap_err(); assert_contains!( err.to_string(), - "Execution error: GROUPS mode requires an ORDER BY clause".to_owned() + "Error during planning: GROUPS mode requires an ORDER BY clause".to_owned() ); Ok(()) } @@ -1034,7 +1035,7 @@ async fn window_frame_creation() -> Result<()> { let results = df.collect().await; assert_eq!( results.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)" + "Error during planning: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)" ); let df = ctx @@ -1047,20 +1048,20 @@ async fn window_frame_creation() -> Result<()> { let results = df.collect().await; assert_eq!( results.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound (2 FOLLOWING) cannot be larger than end bound (1 FOLLOWING)" + "Error during planning: Invalid window frame: start bound (2 FOLLOWING) cannot be larger than end bound (1 FOLLOWING)" ); - let df = ctx + let err = ctx .sql( "SELECT COUNT(c1) OVER(GROUPS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM aggregate_test_100;", ) - .await?; - let results = df.collect().await; + .await + .unwrap_err(); assert_contains!( - results.err().unwrap().to_string(), - "Execution error: GROUPS mode requires an ORDER BY clause" + err.to_string(), + "Error during planning: GROUPS mode requires an ORDER BY clause" ); Ok(()) diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index bf74d02b7005..c25d2491e45a 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -144,6 +144,36 @@ impl WindowFrame { } } +/// Construct equivalent explicit window frames for implicit corner cases. +/// With this processing, we may assume in downstream code that RANGE/GROUPS +/// frames contain an appropriate ORDER BY clause. +pub fn regularize(mut frame: WindowFrame, order_bys: usize) -> Result { + if frame.units == WindowFrameUnits::Range && order_bys != 1 { + // Normally, RANGE frames require an ORDER BY clause with exactly one + // column. However, an ORDER BY clause may be absent in two edge cases. + if (frame.start_bound.is_unbounded() + || frame.start_bound == WindowFrameBound::CurrentRow) + && (frame.end_bound == WindowFrameBound::CurrentRow + || frame.end_bound.is_unbounded()) + { + if order_bys == 0 { + frame.units = WindowFrameUnits::Rows; + frame.start_bound = + WindowFrameBound::Preceding(ScalarValue::UInt64(None)); + frame.end_bound = WindowFrameBound::Following(ScalarValue::UInt64(None)); + } + } else { + return Err(DataFusionError::Plan(format!( + "With window frame of type RANGE, the ORDER BY expression must be of length 1, got {}", order_bys))); + } + } else if frame.units == WindowFrameUnits::Groups && order_bys == 0 { + return Err(DataFusionError::Plan( + "GROUPS mode requires an ORDER BY clause".to_string(), + )); + }; + Ok(frame) +} + /// There are five ways to describe starting and ending frame boundaries: /// /// 1. UNBOUNDED PRECEDING diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index b282a8103267..64abacde49c1 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -32,16 +32,18 @@ use std::sync::Arc; /// This object stores the window frame state for use in incremental calculations. #[derive(Debug)] pub enum WindowFrameContext<'a> { - // ROWS-frames are inherently stateless: + /// ROWS frames are inherently stateless. Rows(&'a Arc), - // RANGE-frames store window frame to calculate window frame boundaries - // In `state`, it keeps track of `last_range` calculated to increase search speed. + /// RANGE frames are stateful, they store indices specifying where the + /// previous search left off. This amortizes the overall cost to O(n) + /// where n denotes the row count. Range { window_frame: &'a Arc, state: WindowFrameStateRange, }, - // GROUPS-frames store window frame to calculate window frame boundaries - // In `state`, we store the boundaries of each group from the start of the table. + /// GROUPS frames are stateful, they store group boundaries and indices + /// specifying where the previous search left off. This amortizes the + /// overall cost to O(n) where n denotes the row count. Groups { window_frame: &'a Arc, state: WindowFrameStateGroups, @@ -49,7 +51,7 @@ pub enum WindowFrameContext<'a> { } impl<'a> WindowFrameContext<'a> { - /// Create a new default state for the given window frame. + /// Create a new state object for the given window frame. pub fn new( window_frame: &'a Arc, sort_options: Vec, @@ -156,12 +158,14 @@ impl<'a> WindowFrameContext<'a> { } } -/// This structure encapsulates all the state information we require as we -/// scan ranges of data while processing window frames. -/// `last_range` keeps the range calculated at the last search. Since we know that range only can progress forward -/// at the next search we start from `last_range`. This makes linear search amortized constant. -/// `sort_options` keeps the ordering of the columns in the ORDER BY clause. This information is used to calculate -/// range boundary, +/// This structure encapsulates all the state information we require as we scan +/// ranges of data while processing RANGE frames. Attribute `last_range` stores +/// the resulting indices from the previous search. Since the indices only +/// advance forward, we start from `last_range` subsequently. Thus, the overall +/// time complexity of linear search amortizes to O(n) where n denotes the total +/// row count. +/// Attribute `sort_options` stores the column ordering specified by the ORDER +/// BY clause. This information is used to calculate the range. #[derive(Debug, Default)] pub struct WindowFrameStateRange { last_range: Range, @@ -169,10 +173,10 @@ pub struct WindowFrameStateRange { } impl WindowFrameStateRange { - /// Creates new struct for range calculation + /// Create a new object to store the search state. fn new(sort_options: Vec, last_range: Range) -> Self { Self { - // Keeps the search range we last calculated + // Stores the search range we calculate for future use. last_range, sort_options, } @@ -242,7 +246,7 @@ impl WindowFrameStateRange { } } }; - // Store last calculated range, to start where we left of in the next iteration + // Store the resulting range so we can start from here subsequently: self.last_range.start = start; self.last_range.end = end; Ok(Range { start, end }) @@ -348,13 +352,6 @@ impl WindowFrameStateGroups { length: usize, idx: usize, ) -> Result> { - // Groups mode should have an ordering - // e.g `range_columns` shouldn't be empty - if range_columns.is_empty() { - return Err(DataFusionError::Execution( - "GROUPS mode requires an ORDER BY clause".to_string(), - )); - } let start = match window_frame.start_bound { WindowFrameBound::Preceding(ref n) => { if n.is_null() { diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index a74874586ce5..af1391cd79d5 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -43,9 +43,10 @@ use datafusion_expr::{ regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, substring, tan, to_hex, to_timestamp_micros, to_timestamp_millis, - to_timestamp_seconds, translate, trim, trunc, upper, uuid, AggregateFunction, - Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, - GetIndexedField, GroupingSet, + to_timestamp_seconds, translate, trim, trunc, upper, uuid, + window_frame::regularize, + AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, + Case, Cast, Expr, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -907,16 +908,13 @@ pub fn parse_expr( .window_frame .as_ref() .map::, _>(|window_frame| { - let window_frame: WindowFrame = window_frame.clone().try_into()?; - if WindowFrameUnits::Range == window_frame.units - && order_by.len() != 1 - { - Err(proto_error("With window frame of type RANGE, the order by expression must be of length 1")) - } else { - Ok(window_frame) - } + let window_frame = window_frame.clone().try_into()?; + regularize(window_frame, order_by.len()) }) - .transpose()?.ok_or_else(||{DataFusionError::Execution("expects somothing".to_string())})?; + .transpose()? + .ok_or_else(|| { + DataFusionError::Execution("expects something".to_string()) + })?; match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index f1343062a047..c5f23213aa31 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -19,9 +19,10 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::utils::normalize_ident; use datafusion_common::{DFSchema, DataFusionError, Result}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_expr::window_frame::regularize; use datafusion_expr::{ expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunction, + WindowFunction, }; use sqlparser::ast::{ Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, @@ -65,26 +66,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .window_frame .as_ref() .map(|window_frame| { - let window_frame: WindowFrame = window_frame.clone().try_into()?; - if WindowFrameUnits::Range == window_frame.units - && order_by.len() != 1 - { - // Construct equivalent window frame. Only in below expressions for RANGE ORDER BY is not a requirement - // In the downstream we will assume that RANGE contains ORDER BY clause. - if (window_frame.start_bound.is_unbounded() || window_frame.start_bound == WindowFrameBound::CurrentRow) - && (window_frame.end_bound == WindowFrameBound::CurrentRow || window_frame.end_bound.is_unbounded()) { - if order_by.is_empty() { - Ok(WindowFrame::new(false)) - } else { - Ok(window_frame) - } - } else { - Err(DataFusionError::Plan(format!( - "With window frame of type RANGE, the order by expression must be of length 1, got {}", order_by.len()))) - } - } else { - Ok(window_frame) - } + let window_frame = window_frame.clone().try_into()?; + regularize(window_frame, order_by.len()) }) .transpose()?; let window_frame = if let Some(window_frame) = window_frame { From 9df0b830677314ab2dbe854fca2905d55d2d679b Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Thu, 16 Feb 2023 15:23:58 -0600 Subject: [PATCH 18/18] Incorporate review suggestion (with cargo fmt fix) --- datafusion/proto/src/logical_plan/from_proto.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index af1391cd79d5..498563b2ab47 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -913,7 +913,9 @@ pub fn parse_expr( }) .transpose()? .ok_or_else(|| { - DataFusionError::Execution("expects something".to_string()) + DataFusionError::Execution( + "missing window frame during deserialization".to_string(), + ) })?; match window_function {