Skip to content

Commit

Permalink
Fix output_partitioning(), output_ordering(), equivalence_properties(…
Browse files Browse the repository at this point in the history
…) in WindowAggExec, shift the Column indexes (#4455)

* Fix output_partitioning(), output_ordering(), equivalence_properties() in WindowAggExec, shift the Column indexes

* resolve review comments
  • Loading branch information
mingmwang authored Dec 2, 2022
1 parent 9bee14e commit 40e44a6
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 7 deletions.
94 changes: 87 additions & 7 deletions datafusion/core/src/physical_plan/windows/window_agg_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::physical_plan::metrics::{
BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
};
use crate::physical_plan::{
ColumnStatistics, DisplayFormatType, Distribution, EquivalenceProperties,
Column, ColumnStatistics, DisplayFormatType, Distribution, EquivalenceProperties,
ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream,
SendableRecordBatchStream, Statistics, WindowExpr,
};
Expand All @@ -35,6 +35,8 @@ use arrow::{
error::{ArrowError, Result as ArrowResult},
record_batch::RecordBatch,
};
use datafusion_physical_expr::rewrite::TreeNodeRewritable;
use datafusion_physical_expr::EquivalentClass;
use futures::stream::Stream;
use futures::{ready, StreamExt};
use log::debug;
Expand All @@ -58,6 +60,8 @@ pub struct WindowAggExec {
pub partition_keys: Vec<Arc<dyn PhysicalExpr>>,
/// Sort Keys
pub sort_keys: Option<Vec<PhysicalSortExpr>>,
/// The output ordering
output_ordering: Option<Vec<PhysicalSortExpr>>,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
}
Expand All @@ -73,13 +77,42 @@ impl WindowAggExec {
) -> Result<Self> {
let schema = create_schema(&input_schema, &window_expr)?;
let schema = Arc::new(schema);
let window_expr_len = window_expr.len();
// Although WindowAggExec does not change the output ordering from the input, but can not return the output ordering
// from the input directly, need to adjust the column index to align with the new schema.
let output_ordering = input
.output_ordering()
.map(|sort_exprs| {
let new_sort_exprs: Result<Vec<PhysicalSortExpr>> = sort_exprs
.iter()
.map(|e| {
let new_expr = e.expr.clone().transform_down(&|e| {
Ok(e.as_any().downcast_ref::<Column>().map(|col| {
Arc::new(Column::new(
col.name(),
window_expr_len + col.index(),
))
as Arc<dyn PhysicalExpr>
}))
})?;
Ok(PhysicalSortExpr {
expr: new_expr,
options: e.options,
})
})
.collect();
new_sort_exprs
})
.map_or(Ok(None), |v| v.map(Some))?;

Ok(Self {
input,
window_expr,
schema,
input_schema,
partition_keys,
sort_keys,
output_ordering,
metrics: ExecutionPlanMetricsSet::new(),
})
}
Expand Down Expand Up @@ -116,14 +149,38 @@ impl ExecutionPlan for WindowAggExec {

/// Get the output partitioning of this plan
fn output_partitioning(&self) -> Partitioning {
// because we can have repartitioning using the partition keys
// this would be either 1 or more than 1 depending on the presense of
// repartitioning
self.input.output_partitioning()
// Although WindowAggExec does not change the output partitioning from the input, but can not return the output partitioning
// from the input directly, need to adjust the column index to align with the new schema.
let window_expr_len = self.window_expr.len();
let input_partitioning = self.input.output_partitioning();
match input_partitioning {
Partitioning::RoundRobinBatch(size) => Partitioning::RoundRobinBatch(size),
Partitioning::UnknownPartitioning(size) => {
Partitioning::UnknownPartitioning(size)
}
Partitioning::Hash(exprs, size) => {
let new_exprs = exprs
.into_iter()
.map(|expr| {
expr.transform_down(&|e| {
Ok(e.as_any().downcast_ref::<Column>().map(|col| {
Arc::new(Column::new(
col.name(),
window_expr_len + col.index(),
))
as Arc<dyn PhysicalExpr>
}))
})
.unwrap()
})
.collect::<Vec<_>>();
Partitioning::Hash(new_exprs, size)
}
}
}

fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
self.input.output_ordering()
self.output_ordering.as_deref()
}

fn maintains_input_order(&self) -> bool {
Expand All @@ -146,7 +203,30 @@ impl ExecutionPlan for WindowAggExec {
}

fn equivalence_properties(&self) -> EquivalenceProperties {
self.input.equivalence_properties()
// Although WindowAggExec does not change the equivalence properties from the input, but can not return the equivalence properties
// from the input directly, need to adjust the column index to align with the new schema.
let window_expr_len = self.window_expr.len();
let mut new_properties = EquivalenceProperties::new(self.schema());
let new_eq_classes = self
.input
.equivalence_properties()
.classes()
.iter()
.map(|prop| {
let new_head = Column::new(
prop.head().name(),
window_expr_len + prop.head().index(),
);
let new_others = prop
.others()
.iter()
.map(|col| Column::new(col.name(), window_expr_len + col.index()))
.collect::<Vec<_>>();
EquivalentClass::new(new_head, new_others)
})
.collect::<Vec<_>>();
new_properties.extend(new_eq_classes);
new_properties
}

fn with_new_children(
Expand Down
37 changes: 37 additions & 0 deletions datafusion/core/tests/sql/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1607,3 +1607,40 @@ async fn test_window_frame_nth_value_aggregate() -> Result<()> {
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn test_window_agg_sort() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT
c9,
SUM(c9) OVER(ORDER BY c9) as sum1,
SUM(c9) OVER(ORDER BY c9, c8) as sum2
FROM aggregate_test_100";

let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx.create_logical_plan(sql).expect(&msg);
let state = ctx.state();
let logical_plan = state.optimize(&plan)?;
let physical_plan = state.create_physical_plan(&logical_plan).await?;
let formatted = displayable(physical_plan.as_ref()).indent().to_string();
// Only 1 SortExec was added
let expected = {
vec![
"ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST]@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST]@1 as sum2]",
" WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })]",
" WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })]",
" SortExec: [c9@1 ASC NULLS LAST,c8@0 ASC NULLS LAST]",
]
};

let actual: Vec<&str> = formatted.trim().lines().collect();
let actual_len = actual.len();
let actual_trim_last = &actual[..actual_len - 1];
assert_eq!(
expected, actual_trim_last,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
);
Ok(())
}

0 comments on commit 40e44a6

Please sign in to comment.