Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[window function] Support MIN/MAX with a user-defined sliding window, optimized with a segment tree. #4616

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions datafusion/core/benches/window_query_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,20 @@ fn criterion_benchmark(c: &mut Criterion) {
})
},
);

c.bench_function(
"window order by, u64_narrow, sum functions",
|b| {
b.iter(|| {
query(
ctx.clone(),
"SELECT \
SUM(u64_narrow) OVER (ORDER by u64_narrow desc ROWS BETWEEN 50 PRECEDING AND 50 FOLLOWING) \
FROM t",
)
})
},
);
}

criterion_group!(benches, criterion_benchmark);
Expand Down
29 changes: 29 additions & 0 deletions datafusion/core/tests/sql/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1610,6 +1610,35 @@ async fn test_window_frame_nth_value_aggregate() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_sliding_window_frame_min_max_aggregate() -> Result<()> {
    // Verifies MIN and MAX over bounded (sliding) ROWS frames — the frame
    // shape that exercises the segment-tree fast path in
    // AggregateWindowExpr. Expected values come from the external
    // aggregate_test_100 CSV fixture registered below.
    let config = SessionConfig::new();
    let ctx = SessionContext::with_config(config);
    register_aggregate_csv(&ctx).await?;

    // Two different sliding frames over the same ordering column (c9):
    // MIN looks 2 back / 1 ahead, MAX looks 1 back / 3 ahead.
    // NOTE: the continuation lines below are part of the SQL string
    // literal, so their leading whitespace is intentionally untouched.
    let sql = "SELECT
MIN(c4) OVER(ORDER BY c9 ASC ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) as MIN,
MAX(c4) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING) as MAX
FROM aggregate_test_100
ORDER BY c9
LIMIT 5";

    let actual = execute_to_batches(&ctx, sql).await;
    // First five rows of the result, in c9 order; values are fixed by the
    // fixture's contents.
    let expected = vec![
        "+--------+-------+",
        "| min | max |",
        "+--------+-------+",
        "| -16110 | 3917 |",
        "| -16974 | 15673 |",
        "| -16974 | 15673 |",
        "| -16974 | 15673 |",
        "| -16974 | 20690 |",
        "+--------+-------+",
    ];
    assert_batches_eq!(expected, &actual);
    Ok(())
}

#[tokio::test]
async fn test_window_agg_sort() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
11 changes: 11 additions & 0 deletions datafusion/expr/src/window_frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,17 @@ pub enum WindowFrameBound {
Following(ScalarValue),
}

impl WindowFrameBound {
    /// Returns `true` if this bound is UNBOUNDED PRECEDING / FOLLOWING,
    /// i.e. the bound's offset value is a null `ScalarValue`.
    ///
    /// Uses `ScalarValue::is_null` rather than comparing against
    /// `ScalarValue::Null` so that typed nulls (e.g.
    /// `ScalarValue::UInt64(None)`) are also recognized as unbounded, as
    /// suggested in review.
    pub fn is_unbounded(&self) -> bool {
        match self {
            // PRECEDING and FOLLOWING are unbounded exactly when their
            // offset is null (untyped or typed).
            WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => {
                v.is_null()
            }
            // CURRENT ROW is always a concrete, bounded position.
            WindowFrameBound::CurrentRow => false,
        }
    }
}

impl TryFrom<ast::WindowFrameBound> for WindowFrameBound {
type Error = DataFusionError;

Expand Down
99 changes: 74 additions & 25 deletions datafusion/physical-expr/src/window/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ use arrow::{array::ArrayRef, datatypes::Field};

use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::WindowFrame;
use datafusion_expr::{Accumulator, WindowFrame};

use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
use crate::window::segment_tree::{Operator, SegmentTree};
use crate::{expressions, expressions::PhysicalSortExpr, PhysicalExpr};
use crate::{window::WindowExpr, AggregateExpr};

use super::window_frame_state::WindowFrameContext;
Expand Down Expand Up @@ -96,11 +97,24 @@ impl WindowExpr for AggregateWindowExpr {
self.order_by.iter().map(|o| o.options).collect();
let mut row_wise_results: Vec<ScalarValue> = vec![];
for partition_range in &partition_points {
let mut accumulator = self.aggregate.create_accumulator()?;
let length = partition_range.end - partition_range.start;
let (values, order_bys) =
self.get_values_orderbys(&batch.slice(partition_range.start, length))?;

let mut accumulator = if let Some(opera) =
need_use_segment_tree(&self.window_frame, &self.aggregate)
{
let scalar_values = (0..values[0].len())
.map(|i| ScalarValue::try_from_array(&values[0], i))
.collect::<Result<Vec<_>>>()?;
WindowAccumulator::SegTree(SegmentTree::build(
scalar_values,
opera,
values[0].data_type().clone(),
)?)
} else {
WindowAccumulator::Default(self.aggregate.create_accumulator()?)
};
let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
let mut last_range: (usize, usize) = (0, 0);

Expand All @@ -113,30 +127,40 @@ impl WindowExpr for AggregateWindowExpr {
length,
i,
)?;
let value = if cur_range.0 == cur_range.1 {
// We produce None if the window is empty.
ScalarValue::try_from(self.aggregate.field()?.data_type())?
} else {
// Accumulate any new rows that have entered the window:
let update_bound = cur_range.1 - last_range.1;
if update_bound > 0 {
let update: Vec<ArrayRef> = values
.iter()
.map(|v| v.slice(last_range.1, update_bound))
.collect();
accumulator.update_batch(&update)?
let value;
match &mut accumulator {
WindowAccumulator::Default(acc) => {
if cur_range.0 == cur_range.1 {
// We produce None if the window is empty.
value = ScalarValue::try_from(
self.aggregate.field()?.data_type(),
)?
} else {
// Accumulate any new rows that have entered the window:
let update_bound = cur_range.1 - last_range.1;
if update_bound > 0 {
let update: Vec<ArrayRef> = values
.iter()
.map(|v| v.slice(last_range.1, update_bound))
.collect();
acc.update_batch(&update)?
}
// Remove rows that have now left the window:
let retract_bound = cur_range.0 - last_range.0;
if retract_bound > 0 {
let retract: Vec<ArrayRef> = values
.iter()
.map(|v| v.slice(last_range.0, retract_bound))
.collect();
acc.retract_batch(&retract)?
}
value = acc.evaluate()?
};
}
// Remove rows that have now left the window:
let retract_bound = cur_range.0 - last_range.0;
if retract_bound > 0 {
let retract: Vec<ArrayRef> = values
.iter()
.map(|v| v.slice(last_range.0, retract_bound))
.collect();
accumulator.retract_batch(&retract)?
WindowAccumulator::SegTree(tree) => {
value = tree.query(cur_range.0, cur_range.1)?
}
accumulator.evaluate()?
};
}
row_wise_results.push(value);
last_range = cur_range;
}
Expand All @@ -156,3 +180,28 @@ impl WindowExpr for AggregateWindowExpr {
&self.window_frame
}
}
/// The accumulation strategy used while evaluating one partition of an
/// aggregate window expression.
enum WindowAccumulator {
    // Generic path: an ordinary accumulator that is updated/retracted
    // incrementally as the frame slides.
    Default(Box<dyn Accumulator>),
    // Fast path: a prebuilt segment tree answering range queries in
    // O(log n); used only for MIN/MAX/SUM over bounded sliding frames.
    SegTree(SegmentTree),
}

/// Returns the segment-tree operator for this window aggregate, or `None`
/// when the fast path does not apply.
///
/// The segment tree is only used for MIN, MAX and SUM over a sliding
/// window, i.e. a frame where both the start and the end bound move
/// (neither side is UNBOUNDED).
fn need_use_segment_tree(
    frame: &Arc<WindowFrame>,
    agg: &Arc<dyn AggregateExpr>,
) -> Option<Operator> {
    // Guard: any unbounded side disqualifies the sliding-window fast path.
    if frame.start_bound.is_unbounded() || frame.end_bound.is_unbounded() {
        return None;
    }
    let agg_any = agg.as_any();
    if agg_any.downcast_ref::<expressions::Sum>().is_some() {
        Some(Operator::Add)
    } else if agg_any.downcast_ref::<expressions::Min>().is_some() {
        Some(Operator::Min)
    } else if agg_any.downcast_ref::<expressions::Max>().is_some() {
        Some(Operator::Max)
    } else {
        // Any other aggregate falls back to the generic accumulator.
        None
    }
}
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub(crate) mod nth_value;
pub(crate) mod partition_evaluator;
pub(crate) mod rank;
pub(crate) mod row_number;
pub(crate) mod segment_tree;
mod window_expr;
mod window_frame_state;

Expand Down
170 changes: 170 additions & 0 deletions datafusion/physical-expr/src/window/segment_tree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::aggregate::min_max::{max, min};
use arrow_schema::DataType;
use datafusion_common::{DataFusionError, Result, ScalarValue};

/// The set of aggregate operators that a segment tree can accelerate.
pub enum Operator {
    /// SUM aggregation.
    Add,
    /// MIN aggregation.
    Min,
    /// MAX aggregation.
    Max,
}

impl Operator {
    /// Combines the results of two adjacent intervals in the segment tree.
    ///
    /// This operation must be associative, i.e.
    /// `combine(combine(a, b), c) == combine(a, combine(b, c))`.
    pub fn combine(&self, lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
        match self {
            Operator::Min => min(lhs, rhs),
            Operator::Max => max(lhs, rhs),
            Operator::Add => lhs.add(rhs),
        }
    }
}
// A segment tree is a binary tree where each node contains the combination
// of its children under the operation. It is stored as a flat array of
// size 2 * count.
pub struct SegmentTree {
    // Flat storage: positions [count, 2*count) hold the leaves (the
    // original values); positions [1, count) hold internal nodes, each
    // combining its children at indices 2*i and 2*i + 1. Index 0 is unused.
    buf: Vec<ScalarValue>,
    // Number of leaves, i.e. the length of the original input.
    count: usize,
    // The associative operator (Add / Min / Max) combining two intervals.
    op: Operator,
    // The value type of the leaves; also used to produce a typed Null for
    // empty queries.
    data_type: DataType,
}

impl SegmentTree {
    /// Builds a tree over the given leaf values.
    ///
    /// The buffer is doubled: the leaves end up in the upper half and the
    /// internal nodes are then computed into the lower half by
    /// [`SegmentTree::build_inner`].
    pub fn build(
        mut buf: Vec<ScalarValue>,
        op: Operator,
        data_type: DataType,
    ) -> Result<Self> {
        let len = buf.len();
        // Clone the leaves into the second half of the buffer in one pass.
        // `extend_from_within` is a safe, equally efficient replacement for
        // the original `unsafe { get_unchecked }` push loop.
        buf.reserve_exact(len);
        buf.extend_from_within(..len);
        SegmentTree::build_inner(buf, op, data_type)
    }

    /// Computes the internal nodes bottom-up.
    ///
    /// `buf` must contain the leaves in its upper half and therefore have
    /// even length; violating that is an internal invariant error.
    fn build_inner(
        mut buf: Vec<ScalarValue>,
        op: Operator,
        data_type: DataType,
    ) -> Result<Self> {
        let len = buf.len();
        let count = len >> 1;
        if len & 1 == 1 {
            // Return a recoverable internal error instead of panicking:
            // this is a broken library invariant, and the function already
            // returns Result.
            return Err(DataFusionError::Internal(
                "SegmentTree::build_inner: odd size".to_string(),
            ));
        }
        // Node i combines its children at 2*i and 2*i + 1; walk the
        // internal nodes from the bottom up so children are ready first.
        for i in (1..count).rev() {
            let res = op.combine(&buf[i << 1], &buf[i << 1 | 1])?;
            buf[i] = res;
        }
        Ok(SegmentTree {
            buf,
            count,
            op,
            data_type,
        })
    }

    /// Computes `a[l] op a[l+1] op ... op a[r-1]` in `O(log(len))` time.
    ///
    /// # Errors
    /// Returns an internal error if `l > r`.
    ///
    /// If `l == r` (empty window) this returns a typed Null of the tree's
    /// data type.
    pub fn query(&self, mut l: usize, mut r: usize) -> Result<ScalarValue> {
        if l > r {
            return Err(DataFusionError::Internal(
                "Query SegmentTree l must <= r".to_string(),
            ));
        }
        // Start from a typed Null accumulator.
        // NOTE(review): this relies on Null acting as an identity for
        // `Operator::combine` (add/min/max) — confirm for all operators.
        let mut res = ScalarValue::try_from(&self.data_type)?;
        // Shift the half-open [l, r) range onto the leaf positions.
        l += self.count;
        r += self.count;
        // Classic bottom-up segment-tree walk: at each level fold in the
        // boundary nodes that are not covered by their parents, then move
        // both cursors up one level.
        while l < r {
            if l & 1 == 1 {
                res = self.op.combine(&res, &self.buf[l])?;
                l += 1;
            }
            if r & 1 == 1 {
                r -= 1;
                res = self.op.combine(&res, &self.buf[r])?;
            }
            l >>= 1;
            r >>= 1;
        }
        Ok(res)
    }
}

#[cfg(test)]
mod tests {
    use crate::window::segment_tree::{Operator, SegmentTree};
    use arrow_schema::DataType;
    use datafusion_common::ScalarValue;
    use rand::Rng;

    /// Cross-checks segment-tree range queries for Add/Min/Max against
    /// naive iterator computations over random data and random ranges.
    #[test]
    fn test_query_segment_tree() {
        const TEST_SIZE: usize = 1000;
        const VAL_RANGE: i32 = 10000;

        let mut rng = rand::thread_rng();
        let rand_vals: Vec<i32> = (0..TEST_SIZE)
            .map(|_| rng.gen_range(0..VAL_RANGE))
            .collect();
        let rand_scalar: Vec<ScalarValue> =
            rand_vals.iter().map(|&v| ScalarValue::from(v)).collect();

        let segment_tree_add =
            SegmentTree::build(rand_scalar.clone(), Operator::Add, DataType::Int32)
                .unwrap();
        let segment_tree_min =
            SegmentTree::build(rand_scalar.clone(), Operator::Min, DataType::Int32)
                .unwrap();
        let segment_tree_max =
            SegmentTree::build(rand_scalar, Operator::Max, DataType::Int32).unwrap();

        for _ in 0..1000 {
            // Always query a non-empty range: start < end.
            let start: usize = rng.gen_range(0..TEST_SIZE - 1);
            let end: usize = rng.gen_range(start + 1..TEST_SIZE);
            let slice = &rand_vals[start..end];

            assert_eq!(
                segment_tree_add.query(start, end).unwrap(),
                ScalarValue::from(slice.iter().sum::<i32>())
            );
            assert_eq!(
                segment_tree_min.query(start, end).unwrap(),
                ScalarValue::from(*slice.iter().min().unwrap())
            );
            assert_eq!(
                segment_tree_max.query(start, end).unwrap(),
                ScalarValue::from(*slice.iter().max().unwrap())
            );
        }
    }
}