Skip to content

Commit

Permalink
Add fetch to SortPreservingMergeExec and SortPreservingMergeStream (apache#6811)
Browse files Browse the repository at this point in the history

* Add fetch to sortpreservingmergeexec

* Add fetch to sortpreservingmergeexec

* fmt

* Deserialize

* Fmt

* Fix test

* Fix test

* Fix test

* Fix plan output

* Doc

* Update datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs

Co-authored-by: Andrew Lamb <[email protected]>

* Extract into method

* Remove from sort enforcement

* Update datafusion/core/src/physical_plan/sorts/merge.rs

Co-authored-by: Mustafa Akur <[email protected]>

* Update datafusion/proto/src/physical_plan/mod.rs

Co-authored-by: Mustafa Akur <[email protected]>

---------

Co-authored-by: Daniël Heres <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
Co-authored-by: Mustafa Akur <[email protected]>
  • Loading branch information
4 people authored and 2010YOUY01 committed Jul 5, 2023
1 parent 80b1ff6 commit faa56d2
Show file tree
Hide file tree
Showing 19 changed files with 103 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl PhysicalOptimizerRule for GlobalSortSelection {
Arc::new(SortPreservingMergeExec::new(
sort_exec.expr().to_vec(),
Arc::new(sort),
));
).with_fetch(sort_exec.fetch()));
Some(global_sort)
} else {
None
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/src/physical_plan/repartition/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ impl ExecutionPlan for RepartitionExec {
sort_exprs,
BaselineMetrics::new(&self.metrics, partition),
context.session_config().batch_size(),
None,
)
} else {
Ok(Box::pin(RepartitionStream {
Expand Down
38 changes: 31 additions & 7 deletions datafusion/core/src/physical_plan/sorts/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ macro_rules! primitive_merge_helper {
}

macro_rules! merge_helper {
($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{
($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident) => {{
let streams = FieldCursorStream::<$t>::new($sort, $streams);
return Ok(Box::pin(SortPreservingMergeStream::new(
Box::new(streams),
$schema,
$tracking_metrics,
$batch_size,
$fetch,
)));
}};
}
Expand All @@ -57,17 +58,18 @@ pub(crate) fn streaming_merge(
expressions: &[PhysicalSortExpr],
metrics: BaselineMetrics,
batch_size: usize,
fetch: Option<usize>,
) -> Result<SendableRecordBatchStream> {
// Special case single column comparisons with optimized cursor implementations
if expressions.len() == 1 {
let sort = expressions[0].clone();
let data_type = sort.expr.data_type(schema.as_ref())?;
downcast_primitive! {
data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size),
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size)
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size)
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size)
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size)
data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch),
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch)
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch)
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch)
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch)
_ => {}
}
}
Expand All @@ -78,6 +80,7 @@ pub(crate) fn streaming_merge(
schema,
metrics,
batch_size,
fetch,
)))
}

Expand Down Expand Up @@ -140,6 +143,12 @@ struct SortPreservingMergeStream<C> {

/// Vector that holds cursors for each non-exhausted input partition
cursors: Vec<Option<C>>,

/// Optional number of rows to fetch
fetch: Option<usize>,

/// number of rows produced
produced: usize,
}

impl<C: Cursor> SortPreservingMergeStream<C> {
Expand All @@ -148,6 +157,7 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
schema: SchemaRef,
metrics: BaselineMetrics,
batch_size: usize,
fetch: Option<usize>,
) -> Self {
let stream_count = streams.partitions();

Expand All @@ -160,6 +170,8 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
loser_tree: vec![],
loser_tree_adjusted: false,
batch_size,
fetch,
produced: 0,
}
}

Expand Down Expand Up @@ -227,15 +239,27 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
if self.advance(stream_idx) {
self.loser_tree_adjusted = false;
self.in_progress.push_row(stream_idx);
if self.in_progress.len() < self.batch_size {

// stop sorting if fetch has been reached
if self.fetch_reached() {
self.aborted = true;
} else if self.in_progress.len() < self.batch_size {
continue;
}
}

self.produced += self.in_progress.len();

return Poll::Ready(self.in_progress.build_record_batch().transpose());
}
}

/// Returns `true` once the rows already emitted (`produced`) plus the rows
/// currently buffered in `in_progress` meet or exceed the optional `fetch`
/// limit.
///
/// Always returns `false` when no `fetch` limit is configured.
///
/// This is a read-only check, so it borrows `self` immutably; callers that
/// hold `&mut self` can still invoke it unchanged.
fn fetch_reached(&self) -> bool {
    self.fetch
        .map_or(false, |limit| self.produced + self.in_progress.len() >= limit)
}

fn advance(&mut self, stream_idx: usize) -> bool {
let slot = &mut self.cursors[stream_idx];
match slot.as_mut() {
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/physical_plan/sorts/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ impl ExternalSorter {
&self.expr,
self.metrics.baseline.clone(),
self.batch_size,
self.fetch,
)
} else if !self.in_mem_batches.is_empty() {
let result = self.in_mem_sort_stream(self.metrics.baseline.clone());
Expand Down Expand Up @@ -285,14 +286,13 @@ impl ExternalSorter {
})
.collect::<Result<_>>()?;

// TODO: Pushdown fetch to streaming merge (#6000)

streaming_merge(
streams,
self.schema.clone(),
&self.expr,
metrics,
self.batch_size,
self.fetch,
)
}

Expand Down
30 changes: 25 additions & 5 deletions datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ pub struct SortPreservingMergeExec {
expr: Vec<PhysicalSortExpr>,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
/// Optional maximum number of rows to fetch. The merge stops producing rows once this limit is reached
fetch: Option<usize>,
}

impl SortPreservingMergeExec {
Expand All @@ -80,8 +82,14 @@ impl SortPreservingMergeExec {
input,
expr,
metrics: ExecutionPlanMetricsSet::new(),
fetch: None,
}
}
/// Builder-style setter for the row limit.
///
/// Consumes the plan and returns it with `fetch` replaced by the given
/// value; passing `None` clears the limit.
pub fn with_fetch(self, fetch: Option<usize>) -> Self {
    let mut plan = self;
    plan.fetch = fetch;
    plan
}

/// Input schema
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
Expand All @@ -92,6 +100,11 @@ impl SortPreservingMergeExec {
pub fn expr(&self) -> &[PhysicalSortExpr] {
&self.expr
}

/// Fetch
pub fn fetch(&self) -> Option<usize> {
self.fetch
}
}

impl ExecutionPlan for SortPreservingMergeExec {
Expand Down Expand Up @@ -137,10 +150,10 @@ impl ExecutionPlan for SortPreservingMergeExec {
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(SortPreservingMergeExec::new(
self.expr.clone(),
children[0].clone(),
)))
Ok(Arc::new(
SortPreservingMergeExec::new(self.expr.clone(), children[0].clone())
.with_fetch(self.fetch),
))
}

fn execute(
Expand Down Expand Up @@ -192,6 +205,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
&self.expr,
BaselineMetrics::new(&self.metrics, partition),
context.session_config().batch_size(),
self.fetch,
)?;

debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");
Expand All @@ -209,7 +223,12 @@ impl ExecutionPlan for SortPreservingMergeExec {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
let expr: Vec<String> = self.expr.iter().map(|e| e.to_string()).collect();
write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))
write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))?;
if let Some(fetch) = self.fetch {
write!(f, ", fetch={fetch}")?;
};

Ok(())
}
}
}
Expand Down Expand Up @@ -814,6 +833,7 @@ mod tests {
sort.as_slice(),
BaselineMetrics::new(&metrics, 0),
task_ctx.session_config().batch_size(),
None,
)
.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ async fn test_physical_plan_display_indent() {
let physical_plan = dataframe.create_physical_plan().await.unwrap();
let expected = vec![
"GlobalLimitExec: skip=0, fetch=10",
" SortPreservingMergeExec: [the_min@2 DESC]",
" SortPreservingMergeExec: [the_min@2 DESC], fetch=10",
" SortExec: fetch=10, expr=[the_min@2 DESC]",
" ProjectionExec: expr=[c1@0 as c1, MAX(aggregate_test_100.c12)@1 as MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)@2 as the_min]",
" AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Limit: skip=0, fetch=10
------------TableScan: nation projection=[n_nationkey, n_name]
physical_plan
GlobalLimitExec: skip=0, fetch=10
--SortPreservingMergeExec: [revenue@2 DESC]
--SortPreservingMergeExec: [revenue@2 DESC], fetch=10
----SortExec: fetch=10, expr=[revenue@2 DESC]
------ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@7 as revenue, c_acctbal@2 as c_acctbal, n_name@4 as n_name, c_address@5 as c_address, c_phone@3 as c_phone, c_comment@6 as c_comment]
--------AggregateExec: mode=FinalPartitioned, gby=[c_custkey@0 as c_custkey, c_name@1 as c_name, c_acctbal@2 as c_acctbal, c_phone@3 as c_phone, n_name@4 as n_name, c_address@5 as c_address, c_comment@6 as c_comment], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Limit: skip=0, fetch=10
----------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")]
physical_plan
GlobalLimitExec: skip=0, fetch=10
--SortPreservingMergeExec: [value@1 DESC]
--SortPreservingMergeExec: [value@1 DESC], fetch=10
----SortExec: fetch=10, expr=[value@1 DESC]
------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value]
--------NestedLoopJoinExec: join_type=Inner, filter=CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Decimal128(38, 15)) > SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)@1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Limit: skip=0, fetch=10
------------------------TableScan: orders projection=[o_orderkey, o_custkey, o_comment], partial_filters=[orders.o_comment NOT LIKE Utf8("%special%requests%")]
physical_plan
GlobalLimitExec: skip=0, fetch=10
--SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC]
--SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC], fetch=10
----SortExec: fetch=10, expr=[custdist@1 DESC,c_count@0 DESC]
------ProjectionExec: expr=[c_count@0 as c_count, COUNT(UInt8(1))@1 as custdist]
--------AggregateExec: mode=FinalPartitioned, gby=[c_count@0 as c_count], aggr=[COUNT(UInt8(1))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Limit: skip=0, fetch=10
------------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")]
physical_plan
GlobalLimitExec: skip=0, fetch=10
--SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST]
--SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], fetch=10
----SortExec: fetch=10, expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST]
------ProjectionExec: expr=[group_alias_0@0 as part.p_brand, group_alias_1@1 as part.p_type, group_alias_2@2 as part.p_size, COUNT(alias1)@3 as supplier_cnt]
--------AggregateExec: mode=FinalPartitioned, gby=[group_alias_0@0 as group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as group_alias_2], aggr=[COUNT(alias1)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ Limit: skip=0, fetch=10
----------------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")]
physical_plan
GlobalLimitExec: skip=0, fetch=10
--SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST]
--SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST], fetch=10
----SortExec: fetch=10, expr=[s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST]
------ProjectionExec: expr=[s_acctbal@5 as s_acctbal, s_name@2 as s_name, n_name@8 as n_name, p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_address@3 as s_address, s_phone@4 as s_phone, s_comment@6 as s_comment]
--------CoalesceBatchesExec: target_batch_size=8192
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Limit: skip=0, fetch=10
----------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate > Date32("9204")]
physical_plan
GlobalLimitExec: skip=0, fetch=10
--SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST]
--SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST], fetch=10
----SortExec: fetch=10, expr=[revenue@1 DESC,o_orderdate@2 ASC NULLS LAST]
------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@3 as revenue, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority]
--------AggregateExec: mode=FinalPartitioned, gby=[l_orderkey@0 as l_orderkey, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ Limit: skip=0, fetch=10
--------------TableScan: nation projection=[n_nationkey, n_name]
physical_plan
GlobalLimitExec: skip=0, fetch=10
--SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC]
--SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC], fetch=10
----SortExec: fetch=10, expr=[nation@0 ASC NULLS LAST,o_year@1 DESC]
------ProjectionExec: expr=[nation@0 as nation, o_year@1 as o_year, SUM(profit.amount)@2 as sum_profit]
--------AggregateExec: mode=FinalPartitioned, gby=[nation@0 as nation, o_year@1 as o_year], aggr=[SUM(profit.amount)]
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sqllogictests/test_files/union.slt
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ Limit: skip=0, fetch=5
--------TableScan: aggregate_test_100 projection=[c1, c3]
physical_plan
GlobalLimitExec: skip=0, fetch=5
--SortPreservingMergeExec: [c9@1 DESC]
--SortPreservingMergeExec: [c9@1 DESC], fetch=5
----UnionExec
------SortExec: expr=[c9@1 DESC]
--------ProjectionExec: expr=[c1@0 as c1, CAST(c9@1 AS Int64) as c9]
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sqllogictests/test_files/window.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1792,7 +1792,7 @@ Limit: skip=0, fetch=5
------------TableScan: aggregate_test_100 projection=[c2, c3, c9]
physical_plan
GlobalLimitExec: skip=0, fetch=5
--SortPreservingMergeExec: [c3@0 ASC NULLS LAST]
--SortPreservingMergeExec: [c3@0 ASC NULLS LAST], fetch=5
----ProjectionExec: expr=[c3@0 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum2]
------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted]
--------SortExec: expr=[c3@0 ASC NULLS LAST,c9@1 DESC]
Expand Down
2 changes: 2 additions & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1366,6 +1366,8 @@ message SortExecNode {
// Wire representation of a sort-preserving merge plan node.
// NOTE(review): presumably mirrors the Rust `SortPreservingMergeExec` operator; confirm against the proto (de)serialization code.
message SortPreservingMergeExecNode {
// Input plan node (presumably the child whose sorted partitions are merged; verify against the deserializer).
PhysicalPlanNode input = 1;
// Expression list for the merge (presumably the sort expressions defining the output order; verify against the deserializer).
repeated PhysicalExprNode expr = 2;
// Maximum number of highest/lowest rows to fetch; negative means no limit
int64 fetch = 3;
}

message CoalesceBatchesExecNode {
Expand Down
19 changes: 19 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit faa56d2

Please sign in to comment.