Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support unparsing LogicalPlan::Window nodes #10767

Merged
merged 3 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ impl Unparser<'_> {
.map(|expr| expr_to_unparsed(expr)?.into_order_by_expr())
.collect::<Result<Vec<_>>>()?;

let start_bound = self.convert_bound(&window_frame.start_bound);
let end_bound = self.convert_bound(&window_frame.end_bound);
let start_bound = self.convert_bound(&window_frame.start_bound)?;
let end_bound = self.convert_bound(&window_frame.end_bound)?;
let over = Some(ast::WindowType::WindowSpec(ast::WindowSpec {
window_name: None,
partition_by: partition_by
Expand Down Expand Up @@ -513,20 +513,30 @@ impl Unparser<'_> {
fn convert_bound(
&self,
bound: &datafusion_expr::window_frame::WindowFrameBound,
) -> ast::WindowFrameBound {
) -> Result<ast::WindowFrameBound> {
match bound {
datafusion_expr::window_frame::WindowFrameBound::Preceding(val) => {
ast::WindowFrameBound::Preceding(
self.scalar_to_sql(val).map(Box::new).ok(),
)
Ok(ast::WindowFrameBound::Preceding({
let val = self.scalar_to_sql(val)?;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a subtle difference in how datafusion plans a window frame bound that is None vs ScalarValue::Null.

The former yields PRECEDING UNBOUNDED and the latter yields PRECEDING NULL.

Datafusion's planner accepts the former, but rejects the latter.

if let ast::Expr::Value(ast::Value::Null) = &val {
None
} else {
Some(Box::new(val))
}
}))
}
datafusion_expr::window_frame::WindowFrameBound::Following(val) => {
ast::WindowFrameBound::Following(
self.scalar_to_sql(val).map(Box::new).ok(),
)
Ok(ast::WindowFrameBound::Following({
let val = self.scalar_to_sql(val)?;
if let ast::Expr::Value(ast::Value::Null) = &val {
None
} else {
Some(Box::new(val))
}
}))
}
datafusion_expr::window_frame::WindowFrameBound::CurrentRow => {
ast::WindowFrameBound::CurrentRow
Ok(ast::WindowFrameBound::CurrentRow)
}
}
}
Expand Down Expand Up @@ -1148,7 +1158,7 @@ mod tests {
window_frame: WindowFrame::new(None),
null_treatment: None,
}),
r#"ROW_NUMBER(col) OVER (ROWS BETWEEN NULL PRECEDING AND NULL FOLLOWING)"#,
r#"ROW_NUMBER(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment on L520 for explanation of this test change.

),
(
Expr::WindowFunction(WindowFunction {
Expand Down
71 changes: 48 additions & 23 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use super::{
BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder,
SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
},
utils::find_agg_node_within_select,
utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
Unparser,
};

Expand Down Expand Up @@ -162,23 +162,42 @@ impl Unparser<'_> {
// A second projection implies a derived tablefactor
if !select.already_projected() {
// Special handling when projecting an agregation plan
if let Some(agg) = find_agg_node_within_select(plan, true) {
let items = p
.expr
.iter()
.map(|proj_expr| {
let unproj = unproject_agg_exprs(proj_expr, agg)?;
self.select_item_to_sql(&unproj)
})
.collect::<Result<Vec<_>>>()?;

select.projection(items);
select.group_by(ast::GroupByExpr::Expressions(
agg.group_expr
.iter()
.map(|expr| self.expr_to_sql(expr))
.collect::<Result<Vec<_>>>()?,
));
if let Some(aggvariant) =
find_agg_node_within_select(plan, None, true)
{
match aggvariant {
AggVariant::Aggregate(agg) => {
let items = p
.expr
.iter()
.map(|proj_expr| {
let unproj = unproject_agg_exprs(proj_expr, agg)?;
self.select_item_to_sql(&unproj)
})
.collect::<Result<Vec<_>>>()?;

select.projection(items);
select.group_by(ast::GroupByExpr::Expressions(
agg.group_expr
.iter()
.map(|expr| self.expr_to_sql(expr))
.collect::<Result<Vec<_>>>()?,
));
}
AggVariant::Window(window) => {
let items = p
.expr
.iter()
.map(|proj_expr| {
let unproj =
unproject_window_exprs(proj_expr, &window)?;
self.select_item_to_sql(&unproj)
})
.collect::<Result<Vec<_>>>()?;

select.projection(items);
}
}
} else {
let items = p
.expr
Expand Down Expand Up @@ -210,8 +229,8 @@ impl Unparser<'_> {
}
}
LogicalPlan::Filter(filter) => {
if let Some(agg) =
find_agg_node_within_select(plan, select.already_projected())
if let Some(AggVariant::Aggregate(agg)) =
find_agg_node_within_select(plan, None, select.already_projected())
{
let unprojected = unproject_agg_exprs(&filter.predicate, agg)?;
let filter_expr = self.expr_to_sql(&unprojected)?;
Expand Down Expand Up @@ -265,7 +284,7 @@ impl Unparser<'_> {
)
}
LogicalPlan::Aggregate(agg) => {
// Aggregate nodes are handled simulatenously with Projection nodes
// Aggregate nodes are handled simultaneously with Projection nodes
self.select_to_sql_recursively(
agg.input.as_ref(),
query,
Expand Down Expand Up @@ -441,8 +460,14 @@ impl Unparser<'_> {

Ok(())
}
LogicalPlan::Window(_window) => {
not_impl_err!("Unsupported operator: {plan:?}")
LogicalPlan::Window(window) => {
// Window nodes are handled simultaneously with Projection nodes
self.select_to_sql_recursively(
window.input.as_ref(),
query,
select,
relation,
)
}
LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"),
_ => not_impl_err!("Unsupported operator: {plan:?}"),
Expand Down
68 changes: 56 additions & 12 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,24 @@ use datafusion_common::{
tree_node::{Transformed, TreeNode},
Result,
};
use datafusion_expr::{Aggregate, Expr, LogicalPlan};
use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window};

/// Recursively searches children of [LogicalPlan] to find an Aggregate node if one exists
/// One of the possible aggregation plans which can be found within a single select query.
pub(crate) enum AggVariant<'a> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes a SELECT query can exclusively have only a window function or an aggregate function but not both. A LogicalPlan can certainly have both, but I could not find an example of a single SELECT query without any nesting/derived table factors that is allowed to have both.

Aggregate(&'a Aggregate),
Window(Vec<&'a Window>),
}

/// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
/// If an Aggregate node is not found prior to this or at all before reaching the end
/// of the tree, None is returned.
pub(crate) fn find_agg_node_within_select(
plan: &LogicalPlan,
/// If an Aggregate or window node is not found prior to this or at all before reaching the end
/// of the tree, None is returned. It is assumed that a Window and Aggegate node cannot both
/// be found in a single select query.
pub(crate) fn find_agg_node_within_select<'a>(
plan: &'a LogicalPlan,
mut prev_windows: Option<AggVariant<'a>>,
already_projected: bool,
) -> Option<&Aggregate> {
) -> Option<AggVariant<'a>> {
// Note that none of the nodes that have a corresponding agg node can have more
// than 1 input node. E.g. Projection / Filter always have 1 input node.
let input = plan.inputs();
Expand All @@ -38,18 +46,29 @@ pub(crate) fn find_agg_node_within_select(
} else {
input.first()?
};
// Agg nodes explicitly return immediately with a single node
// Window nodes accumulate in a vec until encountering a TableScan or 2nd projection
if let LogicalPlan::Aggregate(agg) = input {
Some(agg)
Some(AggVariant::Aggregate(agg))
} else if let LogicalPlan::Window(window) = input {
prev_windows = match &mut prev_windows {
Some(AggVariant::Window(windows)) => {
windows.push(window);
prev_windows
}
_ => Some(AggVariant::Window(vec![window])),
};
find_agg_node_within_select(input, prev_windows, already_projected)
} else if let LogicalPlan::TableScan(_) = input {
None
prev_windows
} else if let LogicalPlan::Projection(_) = input {
if already_projected {
None
prev_windows
} else {
find_agg_node_within_select(input, true)
find_agg_node_within_select(input, prev_windows, true)
}
} else {
find_agg_node_within_select(input, already_projected)
find_agg_node_within_select(input, prev_windows, already_projected)
}
}

Expand Down Expand Up @@ -82,3 +101,28 @@ pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result<Expr>
})
.map(|e| e.data)
}

/// Recursively identify all Column expressions and transform them into the appropriate
/// window expression contained in window.
///
/// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed
/// into an actual window expression as identified in the window node.
pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result<Expr> {
expr.clone()
.transform(|sub_expr| {
if let Expr::Column(c) = sub_expr {
if let Some(unproj) = windows
.iter()
.flat_map(|w| w.window_expr.iter())
.find(|window_expr| window_expr.display_name().unwrap() == c.name)
{
Ok(Transformed::yes(unproj.clone()))
} else {
Ok(Transformed::no(Expr::Column(c)))
}
} else {
Ok(Transformed::no(sub_expr))
}
})
.map(|e| e.data)
}
8 changes: 7 additions & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,13 @@ fn roundtrip_statement() -> Result<()> {
UNION ALL
SELECT j2_string as string FROM j2
ORDER BY string DESC
LIMIT 10"#
LIMIT 10"#,
"SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The roundtrip test will fail if ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING is not explicitly included. E.g. the datafusion planner generates a non identical plan for the following two SQL queries:

SELECT id, 
count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), 
last_name, 
sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
first_name from person

vs

SELECT id, 
count(*) over (PARTITION BY first_name), 
last_name, 
sum(id) over (PARTITION BY first_name),
first_name from person

While the two plans are not identical p1!=p2, I believe the difference is trivial and will actually result in the same computations taking place.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While the two plans are not identical p1!=p2, I believe the difference is trivial and will actually result in the same computations taking place.

That is my understanding as well

We could plausibly simply the resulting plan of the window bounds are the default

Can we also add some tests that have aggregate and window functions? Something like

SELECT id, count(distinct id), sum(id) OVER (PARTITION BY first_name) from person

SELECT id, sum(id) OVER (PARTITION BY first_name ROWS 5 PRECEDING ROWS 2 FOLLOWING) from person

Copy link
Contributor Author

@devinjdangelo devinjdangelo Jun 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It appears that the datafusion planner does not support mixing aggregate and window functions. It does allow mixing window functions with different WindowSpecs, including some over all rows (which is almost the same thing as an aggregate function). I think this behavior makes sense as an aggregate function is strict in how many tuples it will return (one per group) while a window function can return multiple tuples per group as needed.

DataFusion CLI v38.0.0
> create table person (id int, name varchar);
0 row(s) fetched. 
Elapsed 0.001 seconds.
> insert into person values (1, 'a'), (2, 'b'), (3, 'c');
+-------+
| count |
+-------+
| 3     |
+-------+
1 row(s) fetched. 
Elapsed 0.001 seconds.
> select count(distinct id), sum(id) over (partition by name) from person;
Error during planning: Projection references non-aggregate values: Expression person.id could not be resolved from available columns: COUNT(DISTINCT person.id)
> select count(distinct id) over (), sum(id) over (partition by name) from person;
+---------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
| COUNT(person.id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING | SUM(person.id) PARTITION BY [person.name] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING |
+---------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
| 3                                                                         | 2                                                                                                  |
| 3                                                                         | 1                                                                                                  |
| 3                                                                         | 3                                                                                                  |
+---------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
3 row(s) fetched. 
Elapsed 0.002 seconds.

I added some tests and made a few changes to correctly support unparsing a SELECT query with multiple different WindowSpecs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests look great now -- thanks @devinjdangelo

last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
first_name from person",
r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#,
"SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person",
];

// For each test sql string, we transform as follows:
Expand Down