-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support unparsing LogicalPlan::Window nodes #10767
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -236,8 +236,8 @@ impl Unparser<'_> { | |
.map(|expr| expr_to_unparsed(expr)?.into_order_by_expr()) | ||
.collect::<Result<Vec<_>>>()?; | ||
|
||
let start_bound = self.convert_bound(&window_frame.start_bound); | ||
let end_bound = self.convert_bound(&window_frame.end_bound); | ||
let start_bound = self.convert_bound(&window_frame.start_bound)?; | ||
let end_bound = self.convert_bound(&window_frame.end_bound)?; | ||
let over = Some(ast::WindowType::WindowSpec(ast::WindowSpec { | ||
window_name: None, | ||
partition_by: partition_by | ||
|
@@ -513,20 +513,30 @@ impl Unparser<'_> { | |
fn convert_bound( | ||
&self, | ||
bound: &datafusion_expr::window_frame::WindowFrameBound, | ||
) -> ast::WindowFrameBound { | ||
) -> Result<ast::WindowFrameBound> { | ||
match bound { | ||
datafusion_expr::window_frame::WindowFrameBound::Preceding(val) => { | ||
ast::WindowFrameBound::Preceding( | ||
self.scalar_to_sql(val).map(Box::new).ok(), | ||
) | ||
Ok(ast::WindowFrameBound::Preceding({ | ||
let val = self.scalar_to_sql(val)?; | ||
if let ast::Expr::Value(ast::Value::Null) = &val { | ||
None | ||
} else { | ||
Some(Box::new(val)) | ||
} | ||
})) | ||
} | ||
datafusion_expr::window_frame::WindowFrameBound::Following(val) => { | ||
ast::WindowFrameBound::Following( | ||
self.scalar_to_sql(val).map(Box::new).ok(), | ||
) | ||
Ok(ast::WindowFrameBound::Following({ | ||
let val = self.scalar_to_sql(val)?; | ||
if let ast::Expr::Value(ast::Value::Null) = &val { | ||
None | ||
} else { | ||
Some(Box::new(val)) | ||
} | ||
})) | ||
} | ||
datafusion_expr::window_frame::WindowFrameBound::CurrentRow => { | ||
ast::WindowFrameBound::CurrentRow | ||
Ok(ast::WindowFrameBound::CurrentRow) | ||
} | ||
} | ||
} | ||
|
@@ -1148,7 +1158,7 @@ mod tests { | |
window_frame: WindowFrame::new(None), | ||
null_treatment: None, | ||
}), | ||
r#"ROW_NUMBER(col) OVER (ROWS BETWEEN NULL PRECEDING AND NULL FOLLOWING)"#, | ||
r#"ROW_NUMBER(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See comment on L520 for explanation of this test change. |
||
), | ||
( | ||
Expr::WindowFunction(WindowFunction { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,16 +20,24 @@ use datafusion_common::{ | |
tree_node::{Transformed, TreeNode}, | ||
Result, | ||
}; | ||
use datafusion_expr::{Aggregate, Expr, LogicalPlan}; | ||
use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window}; | ||
|
||
/// Recursively searches children of [LogicalPlan] to find an Aggregate node if one exists | ||
/// One of the possible aggregation plans which can be found within a single select query. | ||
pub(crate) enum AggVariant<'a> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This assumes a SELECT query can exclusively have only a window function or an aggregate function but not both. A LogicalPlan can certainly have both, but I could not find an example of a single SELECT query without any nesting/derived table factors that is allowed to have both. |
||
Aggregate(&'a Aggregate), | ||
Window(Vec<&'a Window>), | ||
} | ||
|
||
/// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists | ||
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). | ||
/// If an Aggregate node is not found prior to this or at all before reaching the end | ||
/// of the tree, None is returned. | ||
pub(crate) fn find_agg_node_within_select( | ||
plan: &LogicalPlan, | ||
/// If an Aggregate or window node is not found prior to this or at all before reaching the end | ||
/// of the tree, None is returned. It is assumed that a Window and Aggegate node cannot both | ||
/// be found in a single select query. | ||
pub(crate) fn find_agg_node_within_select<'a>( | ||
plan: &'a LogicalPlan, | ||
mut prev_windows: Option<AggVariant<'a>>, | ||
already_projected: bool, | ||
) -> Option<&Aggregate> { | ||
) -> Option<AggVariant<'a>> { | ||
// Note that none of the nodes that have a corresponding agg node can have more | ||
// than 1 input node. E.g. Projection / Filter always have 1 input node. | ||
let input = plan.inputs(); | ||
|
@@ -38,18 +46,29 @@ pub(crate) fn find_agg_node_within_select( | |
} else { | ||
input.first()? | ||
}; | ||
// Agg nodes explicitly return immediately with a single node | ||
// Window nodes accumulate in a vec until encountering a TableScan or 2nd projection | ||
if let LogicalPlan::Aggregate(agg) = input { | ||
Some(agg) | ||
Some(AggVariant::Aggregate(agg)) | ||
} else if let LogicalPlan::Window(window) = input { | ||
prev_windows = match &mut prev_windows { | ||
Some(AggVariant::Window(windows)) => { | ||
windows.push(window); | ||
prev_windows | ||
} | ||
_ => Some(AggVariant::Window(vec![window])), | ||
}; | ||
find_agg_node_within_select(input, prev_windows, already_projected) | ||
} else if let LogicalPlan::TableScan(_) = input { | ||
None | ||
prev_windows | ||
} else if let LogicalPlan::Projection(_) = input { | ||
if already_projected { | ||
None | ||
prev_windows | ||
} else { | ||
find_agg_node_within_select(input, true) | ||
find_agg_node_within_select(input, prev_windows, true) | ||
} | ||
} else { | ||
find_agg_node_within_select(input, already_projected) | ||
find_agg_node_within_select(input, prev_windows, already_projected) | ||
} | ||
} | ||
|
||
|
@@ -82,3 +101,28 @@ pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result<Expr> | |
}) | ||
.map(|e| e.data) | ||
} | ||
|
||
/// Recursively identify all Column expressions and transform them into the appropriate | ||
/// window expression contained in window. | ||
/// | ||
/// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed | ||
/// into an actual window expression as identified in the window node. | ||
pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result<Expr> { | ||
expr.clone() | ||
.transform(|sub_expr| { | ||
if let Expr::Column(c) = sub_expr { | ||
if let Some(unproj) = windows | ||
.iter() | ||
.flat_map(|w| w.window_expr.iter()) | ||
.find(|window_expr| window_expr.display_name().unwrap() == c.name) | ||
{ | ||
Ok(Transformed::yes(unproj.clone())) | ||
} else { | ||
Ok(Transformed::no(Expr::Column(c))) | ||
} | ||
} else { | ||
Ok(Transformed::no(sub_expr)) | ||
} | ||
}) | ||
.map(|e| e.data) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -127,7 +127,13 @@ fn roundtrip_statement() -> Result<()> { | |
UNION ALL | ||
SELECT j2_string as string FROM j2 | ||
ORDER BY string DESC | ||
LIMIT 10"# | ||
LIMIT 10"#, | ||
"SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The roundtrip test will fail if SELECT id,
count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
last_name,
sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
first_name from person vs SELECT id,
count(*) over (PARTITION BY first_name),
last_name,
sum(id) over (PARTITION BY first_name),
first_name from person While the two plans are not identical There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
That is my understanding as well We could plausibly simply the resulting plan of the window bounds are the default Can we also add some tests that have aggregate and window functions? Something like SELECT id, count(distinct id), sum(id) OVER (PARTITION BY first_name) from person
SELECT id, sum(id) OVER (PARTITION BY first_name ROWS 5 PRECEDING ROWS 2 FOLLOWING) from person There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It appears that the datafusion planner does not support mixing aggregate and window functions. It does allow mixing window functions with different WindowSpecs, including some over all rows (which is almost the same thing as an aggregate function). I think this behavior makes sense as an aggregate function is strict in how many tuples it will return (one per group) while a window function can return multiple tuples per group as needed. DataFusion CLI v38.0.0
> create table person (id int, name varchar);
0 row(s) fetched.
Elapsed 0.001 seconds.
> insert into person values (1, 'a'), (2, 'b'), (3, 'c');
+-------+
| count |
+-------+
| 3 |
+-------+
1 row(s) fetched.
Elapsed 0.001 seconds.
> select count(distinct id), sum(id) over (partition by name) from person;
Error during planning: Projection references non-aggregate values: Expression person.id could not be resolved from available columns: COUNT(DISTINCT person.id)
> select count(distinct id) over (), sum(id) over (partition by name) from person;
+---------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
| COUNT(person.id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING | SUM(person.id) PARTITION BY [person.name] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING |
+---------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
| 3 | 2 |
| 3 | 1 |
| 3 | 3 |
+---------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
3 row(s) fetched.
Elapsed 0.002 seconds. I added some tests and made a few changes to correctly support unparsing a SELECT query with multiple different WindowSpecs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The tests look great now -- thanks @devinjdangelo |
||
last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), | ||
first_name from person", | ||
r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), | ||
sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#, | ||
"SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person", | ||
]; | ||
|
||
// For each test sql string, we transform as follows: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a subtle difference in how datafusion plans a window frame bound that is
None
vs ScalarValue::Null.The former yields
PRECEDING UNBOUNDED
and the latter yieldsPRECEDING NULL
.Datafusion's planner accepts the former, but rejects the latter.