Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Linear search support for Window Group queries #5286

Merged
merged 22 commits into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1594,7 +1594,7 @@ pub fn create_window_expr_with_name(
})
.collect::<Result<Vec<_>>>()?;
if !is_window_valid(window_frame) {
return Err(DataFusionError::Execution(format!(
return Err(DataFusionError::Plan(format!(
"Invalid window frame: start bound ({}) cannot be larger than end bound ({})",
window_frame.start_bound, window_frame.end_bound
)));
Expand Down
61 changes: 54 additions & 7 deletions datafusion/core/tests/sql/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -982,19 +982,20 @@ async fn window_frame_groups_multiple_order_columns() -> Result<()> {
async fn window_frame_groups_without_order_by() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
// execute the query
let df = ctx
// Try executing an erroneous query (the ORDER BY clause is missing in the
// window frame):
let err = ctx
.sql(
"SELECT
SUM(c4) OVER(PARTITION BY c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
FROM aggregate_test_100
ORDER BY c9;",
)
.await?;
let err = df.collect().await.unwrap_err();
.await
.unwrap_err();
assert_contains!(
err.to_string(),
"Execution error: GROUPS mode requires an ORDER BY clause".to_owned()
"Error during planning: GROUPS mode requires an ORDER BY clause".to_owned()
);
Ok(())
}
Expand Down Expand Up @@ -1034,7 +1035,7 @@ async fn window_frame_creation() -> Result<()> {
let results = df.collect().await;
assert_eq!(
results.err().unwrap().to_string(),
"Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)"
"Error during planning: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)"
);

let df = ctx
Expand All @@ -1047,7 +1048,20 @@ async fn window_frame_creation() -> Result<()> {
let results = df.collect().await;
assert_eq!(
results.err().unwrap().to_string(),
"Execution error: Invalid window frame: start bound (2 FOLLOWING) cannot be larger than end bound (1 FOLLOWING)"
"Error during planning: Invalid window frame: start bound (2 FOLLOWING) cannot be larger than end bound (1 FOLLOWING)"
);

let err = ctx
.sql(
"SELECT
COUNT(c1) OVER(GROUPS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
FROM aggregate_test_100;",
)
.await
.unwrap_err();
assert_contains!(
err.to_string(),
"Error during planning: GROUPS mode requires an ORDER BY clause"
);

Ok(())
Expand Down Expand Up @@ -1123,6 +1137,39 @@ async fn test_window_row_number_aggregate() -> Result<()> {
Ok(())
}

// Runs COUNT(*) over six RANGE frames that use only CURRENT ROW / UNBOUNDED
// bounds and compares the first five result rows (ordered by c9) against a
// fixed table:
//   * cnt1..cnt3 use a two-column ORDER BY (c9, c1);
//   * cnt4..cnt6 have no ORDER BY clause at all.
#[tokio::test]
async fn test_window_range_equivalent_frames() -> Result<()> {
// Default session configuration; the test data is registered through the
// shared `register_aggregate_csv` helper and queried as `aggregate_test_100`.
let config = SessionConfig::new();
let ctx = SessionContext::with_config(config);
register_aggregate_csv(&ctx).await?;
let sql = "SELECT
c9,
COUNT(*) OVER(ORDER BY c9, c1 RANGE BETWEEN CURRENT ROW AND CURRENT ROW) AS cnt1,
COUNT(*) OVER(ORDER BY c9, c1 RANGE UNBOUNDED PRECEDING) AS cnt2,
COUNT(*) OVER(ORDER BY c9, c1 RANGE CURRENT ROW) AS cnt3,
COUNT(*) OVER(RANGE BETWEEN CURRENT ROW AND CURRENT ROW) AS cnt4,
COUNT(*) OVER(RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cnt5,
COUNT(*) OVER(RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS cnt6
FROM aggregate_test_100
ORDER BY c9
LIMIT 5";

let actual = execute_to_batches(&ctx, sql).await;
// Expected output: cnt1 and cnt3 are 1 on every row, cnt2 is a running count
// (1, 2, 3, ...), and the three frames without ORDER BY (cnt4..cnt6) all
// report 100 on every row.
let expected = vec![
"+-----------+------+------+------+------+------+------+",
"| c9 | cnt1 | cnt2 | cnt3 | cnt4 | cnt5 | cnt6 |",
"+-----------+------+------+------+------+------+------+",
"| 28774375 | 1 | 1 | 1 | 100 | 100 | 100 |",
"| 63044568 | 1 | 2 | 1 | 100 | 100 | 100 |",
"| 141047417 | 1 | 3 | 1 | 100 | 100 | 100 |",
"| 141680161 | 1 | 4 | 1 | 100 | 100 | 100 |",
"| 145294611 | 1 | 5 | 1 | 100 | 100 | 100 |",
"+-----------+------+------+------+------+------+------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn test_window_cume_dist() -> Result<()> {
let config = SessionConfig::new();
Expand Down
30 changes: 30 additions & 0 deletions datafusion/expr/src/window_frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,36 @@ impl WindowFrame {
}
}

/// Normalizes a window frame against the number of ORDER BY expressions.
///
/// Implicit corner cases are rewritten into equivalent explicit frames, so
/// downstream code may assume that RANGE/GROUPS frames come with an
/// appropriate ORDER BY clause.
///
/// * `frame` - the frame to normalize; it is returned (possibly modified) on
///   success.
/// * `order_bys` - the number of ORDER BY expressions of the window.
///
/// # Errors
///
/// Returns a `DataFusionError::Plan` when
/// * a GROUPS frame has no ORDER BY expression, or
/// * a RANGE frame does not have exactly one ORDER BY expression and one of
///   its bounds is neither CURRENT ROW nor unbounded.
pub fn regularize(mut frame: WindowFrame, order_bys: usize) -> Result<WindowFrame> {
    // GROUPS frames cannot be evaluated without an ordering.
    if frame.units == WindowFrameUnits::Groups && order_bys == 0 {
        return Err(DataFusionError::Plan(
            "GROUPS mode requires an ORDER BY clause".to_string(),
        ));
    }

    // Only RANGE frames whose ORDER BY is not exactly one column need any
    // further attention; everything else passes through untouched.
    if frame.units != WindowFrameUnits::Range || order_bys == 1 {
        return Ok(frame);
    }

    // A RANGE frame may deviate from the "exactly one ORDER BY column" rule
    // only if neither bound carries an offset, i.e. both bounds are either
    // CURRENT ROW or unbounded.
    let bounds_have_no_offset = (frame.start_bound.is_unbounded()
        || frame.start_bound == WindowFrameBound::CurrentRow)
        && (frame.end_bound == WindowFrameBound::CurrentRow
            || frame.end_bound.is_unbounded());
    if !bounds_have_no_offset {
        return Err(DataFusionError::Plan(format!(
            "With window frame of type RANGE, the ORDER BY expression must be of length 1, got {}", order_bys)));
    }

    // With no ORDER BY at all, replace the frame by an explicit ROWS frame
    // whose bounds are `Preceding(NULL)` .. `Following(NULL)` (presumably
    // UNBOUNDED PRECEDING .. UNBOUNDED FOLLOWING - confirm against
    // `WindowFrameBound`). With several ORDER BY columns the frame is kept
    // as is.
    if order_bys == 0 {
        frame.units = WindowFrameUnits::Rows;
        frame.start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None));
        frame.end_bound = WindowFrameBound::Following(ScalarValue::UInt64(None));
    }
    Ok(frame)
}

/// There are five ways to describe starting and ending frame boundaries:
///
/// 1. UNBOUNDED PRECEDING
Expand Down
32 changes: 15 additions & 17 deletions datafusion/physical-expr/src/window/built_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,15 @@ impl WindowExpr for BuiltInWindowExpr {
let mut row_wise_results = vec![];

let (values, order_bys) = self.get_values_orderbys(batch)?;
let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
let range = Range { start: 0, end: 0 };
let mut window_frame_ctx = WindowFrameContext::new(
&self.window_frame,
sort_options,
Range { start: 0, end: 0 },
);
// We iterate on each row to calculate the window frame range and the window function result
for idx in 0..num_rows {
let range = window_frame_ctx.calculate_range(
&order_bys,
&sort_options,
num_rows,
idx,
&range,
)?;
let range =
window_frame_ctx.calculate_range(&order_bys, num_rows, idx)?;
let value = evaluator.evaluate_inside_range(&values, &range)?;
row_wise_results.push(value);
}
Expand Down Expand Up @@ -168,7 +166,13 @@ impl WindowExpr for BuiltInWindowExpr {
// We iterate on each row to perform a running calculation.
let record_batch = &partition_batch_state.record_batch;
let num_rows = record_batch.num_rows();
let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
let last_range = state.window_frame_range.clone();
let mut window_frame_ctx = WindowFrameContext::new(
&self.window_frame,
sort_options.clone(),
// Start search from the last range
last_range,
);
let sort_partition_points = if evaluator.include_rank() {
let columns = self.sort_columns(record_batch)?;
self.evaluate_partition_points(num_rows, &columns)?
Expand All @@ -179,13 +183,7 @@ impl WindowExpr for BuiltInWindowExpr {
let mut last_range = state.window_frame_range.clone();
for idx in state.last_calculated_index..num_rows {
state.window_frame_range = if self.expr.uses_window_frame() {
window_frame_ctx.calculate_range(
&order_bys,
&sort_options,
num_rows,
idx,
&state.window_frame_range,
)
window_frame_ctx.calculate_range(&order_bys, num_rows, idx)
} else {
evaluator.get_range(state, num_rows)
}?;
Expand Down
27 changes: 8 additions & 19 deletions datafusion/physical-expr/src/window/window_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,10 @@ pub trait AggregateWindowExpr: WindowExpr {

/// Evaluates the window function against the batch.
fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
let mut window_frame_ctx = WindowFrameContext::new(self.get_window_frame());
let mut accumulator = self.get_accumulator()?;
let mut last_range = Range { start: 0, end: 0 };
let mut idx = 0;
self.get_result_column(
&mut accumulator,
batch,
&mut window_frame_ctx,
&mut last_range,
&mut idx,
false,
)
self.get_result_column(&mut accumulator, batch, &mut last_range, &mut idx, false)
}

/// Statefully evaluates the window function against the batch. Maintains
Expand Down Expand Up @@ -207,11 +199,9 @@ pub trait AggregateWindowExpr: WindowExpr {
let mut state = &mut window_state.state;

let record_batch = &partition_batch_state.record_batch;
let mut window_frame_ctx = WindowFrameContext::new(self.get_window_frame());
let out_col = self.get_result_column(
accumulator,
record_batch,
&mut window_frame_ctx,
&mut state.window_frame_range,
&mut state.last_calculated_index,
!partition_batch_state.is_end,
Expand All @@ -230,7 +220,6 @@ pub trait AggregateWindowExpr: WindowExpr {
&self,
accumulator: &mut Box<dyn Accumulator>,
record_batch: &RecordBatch,
window_frame_ctx: &mut WindowFrameContext,
last_range: &mut Range<usize>,
idx: &mut usize,
not_end: bool,
Expand All @@ -240,15 +229,15 @@ pub trait AggregateWindowExpr: WindowExpr {
let length = values[0].len();
let sort_options: Vec<SortOptions> =
self.order_by().iter().map(|o| o.options).collect();
let mut window_frame_ctx = WindowFrameContext::new(
self.get_window_frame(),
sort_options,
// Start search from the last range
last_range.clone(),
);
let mut row_wise_results: Vec<ScalarValue> = vec![];
while *idx < length {
let cur_range = window_frame_ctx.calculate_range(
&order_bys,
&sort_options,
length,
*idx,
last_range,
)?;
let cur_range = window_frame_ctx.calculate_range(&order_bys, length, *idx)?;
// Exit if the range extends all the way:
if cur_range.end == length && not_end {
break;
Expand Down
Loading