From 40e44a64f51438ea8eae7112828ef116409a15f0 Mon Sep 17 00:00:00 2001
From: mingmwang
Date: Sat, 3 Dec 2022 04:50:54 +0800
Subject: [PATCH] Fix output_partitioning(), output_ordering(),
 equivalence_properties() in WindowAggExec, shift the Column indexes (#4455)

* Fix output_partitioning(), output_ordering(), equivalence_properties() in
  WindowAggExec, shift the Column indexes

* resolve review comments
---
 .../physical_plan/windows/window_agg_exec.rs  | 94 +++++++++++++++++--
 datafusion/core/tests/sql/window.rs           | 37 ++++++++
 2 files changed, 124 insertions(+), 7 deletions(-)

diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs
index 2e1b6a70b9da..5b4dc79eaf22 100644
--- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs
+++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs
@@ -24,7 +24,7 @@ use crate::physical_plan::metrics::{
     BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
 };
 use crate::physical_plan::{
-    ColumnStatistics, DisplayFormatType, Distribution, EquivalenceProperties,
+    Column, ColumnStatistics, DisplayFormatType, Distribution, EquivalenceProperties,
     ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream,
     SendableRecordBatchStream, Statistics, WindowExpr,
 };
@@ -35,6 +35,8 @@ use arrow::{
     error::{ArrowError, Result as ArrowResult},
     record_batch::RecordBatch,
 };
+use datafusion_physical_expr::rewrite::TreeNodeRewritable;
+use datafusion_physical_expr::EquivalentClass;
 use futures::stream::Stream;
 use futures::{ready, StreamExt};
 use log::debug;
@@ -58,6 +60,8 @@ pub struct WindowAggExec {
     pub partition_keys: Vec<Arc<dyn PhysicalExpr>>,
     /// Sort Keys
     pub sort_keys: Option<Vec<PhysicalSortExpr>>,
+    /// The output ordering
+    output_ordering: Option<Vec<PhysicalSortExpr>>,
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
 }
@@ -73,6 +77,34 @@ impl WindowAggExec {
     ) -> Result<Self> {
         let schema = create_schema(&input_schema, &window_expr)?;
         let schema = Arc::new(schema);
+        let window_expr_len = window_expr.len();
+        // Although WindowAggExec does not change the ordering of its input, it cannot return the
+        // input's ordering directly; the column indexes must be shifted to align with the new schema.
+        let output_ordering = input
+            .output_ordering()
+            .map(|sort_exprs| {
+                let new_sort_exprs: Result<Vec<PhysicalSortExpr>> = sort_exprs
+                    .iter()
+                    .map(|e| {
+                        let new_expr = e.expr.clone().transform_down(&|e| {
+                            Ok(e.as_any().downcast_ref::<Column>().map(|col| {
+                                Arc::new(Column::new(
+                                    col.name(),
+                                    window_expr_len + col.index(),
+                                ))
+                                    as Arc<dyn PhysicalExpr>
+                            }))
+                        })?;
+                        Ok(PhysicalSortExpr {
+                            expr: new_expr,
+                            options: e.options,
+                        })
+                    })
+                    .collect();
+                new_sort_exprs
+            })
+            .map_or(Ok(None), |v| v.map(Some))?;
+
         Ok(Self {
             input,
             window_expr,
@@ -80,6 +112,7 @@ impl WindowAggExec {
             input_schema,
             partition_keys,
             sort_keys,
+            output_ordering,
             metrics: ExecutionPlanMetricsSet::new(),
         })
     }
@@ -116,14 +149,38 @@ impl ExecutionPlan for WindowAggExec {

     /// Get the output partitioning of this plan
     fn output_partitioning(&self) -> Partitioning {
-        // because we can have repartitioning using the partition keys
-        // this would be either 1 or more than 1 depending on the presense of
-        // repartitioning
-        self.input.output_partitioning()
+        // Although WindowAggExec does not change the partitioning of its input, it cannot return the
+        // input's partitioning directly; the column indexes must be shifted to align with the new schema.
+        let window_expr_len = self.window_expr.len();
+        let input_partitioning = self.input.output_partitioning();
+        match input_partitioning {
+            Partitioning::RoundRobinBatch(size) => Partitioning::RoundRobinBatch(size),
+            Partitioning::UnknownPartitioning(size) => {
+                Partitioning::UnknownPartitioning(size)
+            }
+            Partitioning::Hash(exprs, size) => {
+                let new_exprs = exprs
+                    .into_iter()
+                    .map(|expr| {
+                        expr.transform_down(&|e| {
+                            Ok(e.as_any().downcast_ref::<Column>().map(|col| {
+                                Arc::new(Column::new(
+                                    col.name(),
+                                    window_expr_len + col.index(),
+                                ))
+                                    as Arc<dyn PhysicalExpr>
+                            }))
+                        })
+                        .unwrap()
+                    })
+                    .collect::<Vec<_>>();
+                Partitioning::Hash(new_exprs, size)
+            }
+        }
     }

     fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
-        self.input.output_ordering()
+        self.output_ordering.as_deref()
     }

     fn maintains_input_order(&self) -> bool {
@@ -146,7 +203,30 @@ impl ExecutionPlan for WindowAggExec {
     }

     fn equivalence_properties(&self) -> EquivalenceProperties {
-        self.input.equivalence_properties()
+        // Although WindowAggExec does not change the equivalence properties of its input, it cannot
+        // return them directly; the column indexes must be shifted to align with the new schema.
+        let window_expr_len = self.window_expr.len();
+        let mut new_properties = EquivalenceProperties::new(self.schema());
+        let new_eq_classes = self
+            .input
+            .equivalence_properties()
+            .classes()
+            .iter()
+            .map(|prop| {
+                let new_head = Column::new(
+                    prop.head().name(),
+                    window_expr_len + prop.head().index(),
+                );
+                let new_others = prop
+                    .others()
+                    .iter()
+                    .map(|col| Column::new(col.name(), window_expr_len + col.index()))
+                    .collect::<Vec<_>>();
+                EquivalentClass::new(new_head, new_others)
+            })
+            .collect::<Vec<_>>();
+        new_properties.extend(new_eq_classes);
+        new_properties
     }

     fn with_new_children(
diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs
index 95d0ed92944c..6d30d53f5a6b 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -1607,3 +1607,40 @@ async fn test_window_frame_nth_value_aggregate() -> Result<()> {
     assert_batches_eq!(expected, &actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn test_window_agg_sort() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let sql = "SELECT
+      c9,
+      SUM(c9) OVER(ORDER BY c9) as sum1,
+      SUM(c9) OVER(ORDER BY c9, c8) as sum2
+      FROM aggregate_test_100";
+
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx.create_logical_plan(sql).expect(&msg);
+    let state = ctx.state();
+    let logical_plan = state.optimize(&plan)?;
+    let physical_plan = state.create_physical_plan(&logical_plan).await?;
+    let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+    // Only 1 SortExec was added
+    let expected = {
+        vec![
+            "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST]@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST]@1 as sum2]",
+            "  WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })]",
+            "    WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })]",
+            "      SortExec: [c9@1 ASC NULLS LAST,c8@0 ASC NULLS LAST]",
+        ]
+    };
+
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    let actual_len = actual.len();
+    let actual_trim_last = &actual[..actual_len - 1];
+    assert_eq!(
+        expected, actual_trim_last,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+    Ok(())
+}
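
A note on the index arithmetic behind this patch: at this point in DataFusion, WindowAggExec builds its output schema with the window expression columns placed before the input columns (the new test's expected plan shows the input column c9 at index 3, behind two window columns), so every Column reference inherited from the input must be shifted right by window_expr.len(). The sketch below is a minimal, self-contained illustration of that shift and is not DataFusion code; the Column struct and shift_column helper are simplified stand-ins for datafusion_physical_expr::expressions::Column and the transform_down rewrites above.

    /// Simplified stand-in for DataFusion's physical Column expression:
    /// a name plus an index into the plan's output schema.
    #[derive(Debug, Clone, PartialEq)]
    struct Column {
        name: String,
        index: usize,
    }

    /// Shift a column reference right by the number of window expressions,
    /// mirroring the "window_expr_len + col.index()" rewrite in the patch.
    fn shift_column(col: &Column, window_expr_len: usize) -> Column {
        Column {
            name: col.name.clone(),
            index: window_expr_len + col.index,
        }
    }

    fn main() {
        // Input schema: [c8@0, c9@1]. With two window columns (sum1, sum2)
        // prepended, the output schema becomes [sum1@0, sum2@1, c8@2, c9@3].
        let window_expr_len = 2;
        let input_ordering = vec![
            Column { name: "c9".into(), index: 1 },
            Column { name: "c8".into(), index: 0 },
        ];

        let shifted: Vec<Column> = input_ordering
            .iter()
            .map(|col| shift_column(col, window_expr_len))
            .collect();

        // c9 now points at index 3 and c8 at index 2, consistent with the
        // c9@3 reference in the new test's expected ProjectionExec.
        assert_eq!(shifted[0], Column { name: "c9".into(), index: 3 });
        assert_eq!(shifted[1], Column { name: "c8".into(), index: 2 });
        println!("{:?}", shifted);
    }

The same offset is applied uniformly in the patch to the output ordering, the hash-partitioning expressions, and the equivalence classes, which is why the window_expr_len + col.index() rewrite appears in all three methods.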