Skip to content

Commit

Permalink
feat: support unparsing LogicalPlan::Window nodes (#10767)
Browse files Browse the repository at this point in the history
* unparse window plans

* new tests + fixes

* fmt
  • Loading branch information
devinjdangelo authored Jun 3, 2024
1 parent 180f3e8 commit e4f7b98
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 47 deletions.
32 changes: 21 additions & 11 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ impl Unparser<'_> {
.map(|expr| expr_to_unparsed(expr)?.into_order_by_expr())
.collect::<Result<Vec<_>>>()?;

let start_bound = self.convert_bound(&window_frame.start_bound);
let end_bound = self.convert_bound(&window_frame.end_bound);
let start_bound = self.convert_bound(&window_frame.start_bound)?;
let end_bound = self.convert_bound(&window_frame.end_bound)?;
let over = Some(ast::WindowType::WindowSpec(ast::WindowSpec {
window_name: None,
partition_by: partition_by
Expand Down Expand Up @@ -513,20 +513,30 @@ impl Unparser<'_> {
fn convert_bound(
&self,
bound: &datafusion_expr::window_frame::WindowFrameBound,
) -> ast::WindowFrameBound {
) -> Result<ast::WindowFrameBound> {
match bound {
datafusion_expr::window_frame::WindowFrameBound::Preceding(val) => {
ast::WindowFrameBound::Preceding(
self.scalar_to_sql(val).map(Box::new).ok(),
)
Ok(ast::WindowFrameBound::Preceding({
let val = self.scalar_to_sql(val)?;
if let ast::Expr::Value(ast::Value::Null) = &val {
None
} else {
Some(Box::new(val))
}
}))
}
datafusion_expr::window_frame::WindowFrameBound::Following(val) => {
ast::WindowFrameBound::Following(
self.scalar_to_sql(val).map(Box::new).ok(),
)
Ok(ast::WindowFrameBound::Following({
let val = self.scalar_to_sql(val)?;
if let ast::Expr::Value(ast::Value::Null) = &val {
None
} else {
Some(Box::new(val))
}
}))
}
datafusion_expr::window_frame::WindowFrameBound::CurrentRow => {
ast::WindowFrameBound::CurrentRow
Ok(ast::WindowFrameBound::CurrentRow)
}
}
}
Expand Down Expand Up @@ -1148,7 +1158,7 @@ mod tests {
window_frame: WindowFrame::new(None),
null_treatment: None,
}),
r#"ROW_NUMBER(col) OVER (ROWS BETWEEN NULL PRECEDING AND NULL FOLLOWING)"#,
r#"ROW_NUMBER(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#,
),
(
Expr::WindowFunction(WindowFunction {
Expand Down
71 changes: 48 additions & 23 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use super::{
BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder,
SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
},
utils::find_agg_node_within_select,
utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
Unparser,
};

Expand Down Expand Up @@ -162,23 +162,42 @@ impl Unparser<'_> {
// A second projection implies a derived tablefactor
if !select.already_projected() {
// Special handling when projecting an aggregation plan
if let Some(agg) = find_agg_node_within_select(plan, true) {
let items = p
.expr
.iter()
.map(|proj_expr| {
let unproj = unproject_agg_exprs(proj_expr, agg)?;
self.select_item_to_sql(&unproj)
})
.collect::<Result<Vec<_>>>()?;

select.projection(items);
select.group_by(ast::GroupByExpr::Expressions(
agg.group_expr
.iter()
.map(|expr| self.expr_to_sql(expr))
.collect::<Result<Vec<_>>>()?,
));
if let Some(aggvariant) =
find_agg_node_within_select(plan, None, true)
{
match aggvariant {
AggVariant::Aggregate(agg) => {
let items = p
.expr
.iter()
.map(|proj_expr| {
let unproj = unproject_agg_exprs(proj_expr, agg)?;
self.select_item_to_sql(&unproj)
})
.collect::<Result<Vec<_>>>()?;

select.projection(items);
select.group_by(ast::GroupByExpr::Expressions(
agg.group_expr
.iter()
.map(|expr| self.expr_to_sql(expr))
.collect::<Result<Vec<_>>>()?,
));
}
AggVariant::Window(window) => {
let items = p
.expr
.iter()
.map(|proj_expr| {
let unproj =
unproject_window_exprs(proj_expr, &window)?;
self.select_item_to_sql(&unproj)
})
.collect::<Result<Vec<_>>>()?;

select.projection(items);
}
}
} else {
let items = p
.expr
Expand Down Expand Up @@ -210,8 +229,8 @@ impl Unparser<'_> {
}
}
LogicalPlan::Filter(filter) => {
if let Some(agg) =
find_agg_node_within_select(plan, select.already_projected())
if let Some(AggVariant::Aggregate(agg)) =
find_agg_node_within_select(plan, None, select.already_projected())
{
let unprojected = unproject_agg_exprs(&filter.predicate, agg)?;
let filter_expr = self.expr_to_sql(&unprojected)?;
Expand Down Expand Up @@ -265,7 +284,7 @@ impl Unparser<'_> {
)
}
LogicalPlan::Aggregate(agg) => {
// Aggregate nodes are handled simulatenously with Projection nodes
// Aggregate nodes are handled simultaneously with Projection nodes
self.select_to_sql_recursively(
agg.input.as_ref(),
query,
Expand Down Expand Up @@ -441,8 +460,14 @@ impl Unparser<'_> {

Ok(())
}
LogicalPlan::Window(_window) => {
not_impl_err!("Unsupported operator: {plan:?}")
LogicalPlan::Window(window) => {
// Window nodes are handled simultaneously with Projection nodes
self.select_to_sql_recursively(
window.input.as_ref(),
query,
select,
relation,
)
}
LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"),
_ => not_impl_err!("Unsupported operator: {plan:?}"),
Expand Down
68 changes: 56 additions & 12 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,24 @@ use datafusion_common::{
tree_node::{Transformed, TreeNode},
Result,
};
use datafusion_expr::{Aggregate, Expr, LogicalPlan};
use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window};

/// Recursively searches children of [LogicalPlan] to find an Aggregate node if one exists
/// One of the possible aggregation plans which can be found within a single select query.
///
/// This is the value returned by `find_agg_node_within_select`: either the one
/// `Aggregate` node it encountered, or every `Window` node it collected while
/// walking down the plan's inputs. The lifetime `'a` ties the borrowed nodes to
/// the `LogicalPlan` that was searched.
pub(crate) enum AggVariant<'a> {
/// A borrowed `Aggregate` node; the search returns as soon as one is found.
Aggregate(&'a Aggregate),
/// Borrowed `Window` nodes, pushed in the order the search visits them
/// (a `Vec` because several window nodes may be stacked in one select).
Window(Vec<&'a Window>),
}

/// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
/// If an Aggregate node is not found prior to this or at all before reaching the end
/// of the tree, None is returned.
pub(crate) fn find_agg_node_within_select(
plan: &LogicalPlan,
/// If an Aggregate or window node is not found prior to this or at all before reaching the end
/// of the tree, None is returned. It is assumed that a Window and Aggregate node cannot both
/// be found in a single select query.
pub(crate) fn find_agg_node_within_select<'a>(
plan: &'a LogicalPlan,
mut prev_windows: Option<AggVariant<'a>>,
already_projected: bool,
) -> Option<&Aggregate> {
) -> Option<AggVariant<'a>> {
// Note that none of the nodes that have a corresponding agg node can have more
// than 1 input node. E.g. Projection / Filter always have 1 input node.
let input = plan.inputs();
Expand All @@ -38,18 +46,29 @@ pub(crate) fn find_agg_node_within_select(
} else {
input.first()?
};
// Agg nodes explicitly return immediately with a single node
// Window nodes accumulate in a vec until encountering a TableScan or 2nd projection
if let LogicalPlan::Aggregate(agg) = input {
Some(agg)
Some(AggVariant::Aggregate(agg))
} else if let LogicalPlan::Window(window) = input {
prev_windows = match &mut prev_windows {
Some(AggVariant::Window(windows)) => {
windows.push(window);
prev_windows
}
_ => Some(AggVariant::Window(vec![window])),
};
find_agg_node_within_select(input, prev_windows, already_projected)
} else if let LogicalPlan::TableScan(_) = input {
None
prev_windows
} else if let LogicalPlan::Projection(_) = input {
if already_projected {
None
prev_windows
} else {
find_agg_node_within_select(input, true)
find_agg_node_within_select(input, prev_windows, true)
}
} else {
find_agg_node_within_select(input, already_projected)
find_agg_node_within_select(input, prev_windows, already_projected)
}
}

Expand Down Expand Up @@ -82,3 +101,28 @@ pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result<Expr>
})
.map(|e| e.data)
}

/// Recursively identify all Column expressions and transform them into the appropriate
/// window expression contained in `windows`.
///
/// For example, if `expr` contains the column expr "COUNT(*) PARTITION BY id" it will be
/// transformed into the actual window expression found in one of the window nodes.
///
/// # Arguments
///
/// * `expr` - the expression to rewrite; it is cloned, the input is left untouched.
/// * `windows` - the window nodes whose `window_expr` lists are searched, in order.
///
/// # Returns
///
/// The rewritten expression. A column whose name matches the display name of a
/// window expression is replaced by a clone of the first such window expression;
/// every other sub-expression is kept as-is.
///
/// # Errors
///
/// Returns an error if the display name of a candidate window expression cannot be
/// computed, or if the underlying tree transformation fails.
pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result<Expr> {
    expr.clone()
        .transform(|sub_expr| {
            if let Expr::Column(c) = sub_expr {
                // An explicit loop (rather than `Iterator::find`) lets a failure in
                // `display_name` propagate with `?` instead of panicking on `unwrap`.
                for window_expr in windows.iter().flat_map(|w| w.window_expr.iter()) {
                    if window_expr.display_name()? == c.name {
                        return Ok(Transformed::yes(window_expr.clone()));
                    }
                }
                // No window expression carries this column's name: keep the column.
                Ok(Transformed::no(Expr::Column(c)))
            } else {
                Ok(Transformed::no(sub_expr))
            }
        })
        .map(|e| e.data)
}
8 changes: 7 additions & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,13 @@ fn roundtrip_statement() -> Result<()> {
UNION ALL
SELECT j2_string as string FROM j2
ORDER BY string DESC
LIMIT 10"#
LIMIT 10"#,
"SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
first_name from person",
r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#,
"SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person",
];

// For each test sql string, we transform as follows:
Expand Down

0 comments on commit e4f7b98

Please sign in to comment.