diff --git a/datafusion/core/src/physical_optimizer/global_sort_selection.rs b/datafusion/core/src/physical_optimizer/global_sort_selection.rs
index 9466297d24d00..0b9054f89ff4c 100644
--- a/datafusion/core/src/physical_optimizer/global_sort_selection.rs
+++ b/datafusion/core/src/physical_optimizer/global_sort_selection.rs
@@ -70,7 +70,7 @@ impl PhysicalOptimizerRule for GlobalSortSelection {
                     Arc::new(SortPreservingMergeExec::new(
                         sort_exec.expr().to_vec(),
                         Arc::new(sort),
-                    ));
+                    ).with_fetch(sort_exec.fetch()));
                 Some(global_sort)
             } else {
                 None
diff --git a/datafusion/core/src/physical_plan/repartition/mod.rs b/datafusion/core/src/physical_plan/repartition/mod.rs
index 72ff0c37135b3..85225eb471760 100644
--- a/datafusion/core/src/physical_plan/repartition/mod.rs
+++ b/datafusion/core/src/physical_plan/repartition/mod.rs
@@ -497,6 +497,7 @@ impl ExecutionPlan for RepartitionExec {
                 sort_exprs,
                 BaselineMetrics::new(&self.metrics, partition),
                 context.session_config().batch_size(),
+                None,
             )
         } else {
             Ok(Box::pin(RepartitionStream {
diff --git a/datafusion/core/src/physical_plan/sorts/merge.rs b/datafusion/core/src/physical_plan/sorts/merge.rs
index d8a3cdef4d686..e191c044b9040 100644
--- a/datafusion/core/src/physical_plan/sorts/merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/merge.rs
@@ -39,13 +39,14 @@ macro_rules! primitive_merge_helper {
 }
 
 macro_rules! merge_helper {
-    ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{
+    ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident) => {{
         let streams = FieldCursorStream::<$t>::new($sort, $streams);
         return Ok(Box::pin(SortPreservingMergeStream::new(
             Box::new(streams),
             $schema,
             $tracking_metrics,
             $batch_size,
+            $fetch,
         )));
     }};
 }
@@ -57,17 +58,18 @@ pub(crate) fn streaming_merge(
     expressions: &[PhysicalSortExpr],
     metrics: BaselineMetrics,
     batch_size: usize,
+    fetch: Option<usize>,
 ) -> Result<SendableRecordBatchStream> {
     // Special case single column comparisons with optimized cursor implementations
     if expressions.len() == 1 {
         let sort = expressions[0].clone();
         let data_type = sort.expr.data_type(schema.as_ref())?;
         downcast_primitive! {
-            data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size),
-            DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size)
-            DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size)
-            DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size)
-            DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size)
+            data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch),
+            DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch)
+            DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch)
+            DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch)
+            DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch)
             _ => {}
         }
     }
@@ -78,6 +80,7 @@ pub(crate) fn streaming_merge(
         schema,
         metrics,
         batch_size,
+        fetch,
     )))
 }
 
@@ -140,6 +143,12 @@ struct SortPreservingMergeStream<C> {
 
     /// Vector that holds cursors for each non-exhausted input partition
     cursors: Vec<Option<C>>,
+
+    /// Optional number of rows to fetch
+    fetch: Option<usize>,
+
+    /// number of rows produced
+    produced: usize,
 }
 
 impl<C: Cursor> SortPreservingMergeStream<C> {
@@ -148,6 +157,7 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
         schema: SchemaRef,
         metrics: BaselineMetrics,
         batch_size: usize,
+        fetch: Option<usize>,
     ) -> Self {
         let stream_count = streams.partitions();
 
@@ -160,6 +170,8 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
             loser_tree: vec![],
             loser_tree_adjusted: false,
             batch_size,
+            fetch,
+            produced: 0,
         }
     }
 
@@ -227,15 +239,27 @@
             if self.advance(stream_idx) {
                 self.loser_tree_adjusted = false;
                 self.in_progress.push_row(stream_idx);
-                if self.in_progress.len() < self.batch_size {
+
+                // stop sorting if fetch has been reached
+                if self.fetch_reached() {
+                    self.aborted = true;
+                } else if self.in_progress.len() < self.batch_size {
                     continue;
                 }
             }
 
+            self.produced += self.in_progress.len();
+
             return Poll::Ready(self.in_progress.build_record_batch().transpose());
         }
     }
 
+    fn fetch_reached(&mut self) -> bool {
+        self.fetch
+            .map(|fetch| self.produced + self.in_progress.len() >= fetch)
+            .unwrap_or(false)
+    }
+
     fn advance(&mut self, stream_idx: usize) -> bool {
         let slot = &mut self.cursors[stream_idx];
         match slot.as_mut() {
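The merge.rs change above is the heart of the patch: the loser-tree merge now counts the rows it has emitted (produced) and flips aborted once the optional fetch limit is covered, so the remaining input is never drained. As a standalone illustration of that early-exit idea — plain std Rust using a simple binary heap and hypothetical names, not DataFusion's actual SortPreservingMergeStream — a fetch-limited k-way merge over pre-sorted vectors can be sketched as:

    use std::cmp::Reverse;
    use std::collections::BinaryHeap;

    /// Merge several individually sorted inputs into one sorted output,
    /// stopping early once `fetch` rows have been produced (None = no limit).
    fn merge_sorted_with_fetch(inputs: &[Vec<i32>], fetch: Option<usize>) -> Vec<i32> {
        // Min-heap of (next value, input index, offset within that input).
        let mut heap = BinaryHeap::new();
        for (i, input) in inputs.iter().enumerate() {
            if let Some(&v) = input.first() {
                heap.push(Reverse((v, i, 0usize)));
            }
        }

        let mut out = Vec::new();
        while let Some(Reverse((v, i, pos))) = heap.pop() {
            out.push(v);
            // Early exit once the limit is reached, mirroring the new fetch_reached check.
            if fetch.map(|f| out.len() >= f).unwrap_or(false) {
                break;
            }
            if let Some(&next) = inputs[i].get(pos + 1) {
                heap.push(Reverse((next, i, pos + 1)));
            }
        }
        out
    }

    fn main() {
        let inputs = vec![vec![1, 4, 7], vec![2, 5, 8], vec![3, 6, 9]];
        assert_eq!(merge_sorted_with_fetch(&inputs, Some(4)), vec![1, 2, 3, 4]);
        assert_eq!(merge_sorted_with_fetch(&inputs, None).len(), 9);
    }

Only the first fetch rows are ever pulled off the heap; everything behind them stays untouched, which is exactly what the produced/fetch_reached bookkeeping buys the streaming merge.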
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index 4983b0ea83e5e..205ec706b5dcd 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -189,6 +189,7 @@ impl ExternalSorter {
                 &self.expr,
                 self.metrics.baseline.clone(),
                 self.batch_size,
+                self.fetch,
             )
         } else if !self.in_mem_batches.is_empty() {
             let result = self.in_mem_sort_stream(self.metrics.baseline.clone());
@@ -285,14 +286,13 @@ impl ExternalSorter {
             })
             .collect::<Result<_>>()?;
 
-        // TODO: Pushdown fetch to streaming merge (#6000)
-
         streaming_merge(
             streams,
             self.schema.clone(),
             &self.expr,
             metrics,
             self.batch_size,
+            self.fetch,
         )
     }
 
diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
index 4db1fea2a4f1e..397d254162c72 100644
--- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -71,6 +71,8 @@ pub struct SortPreservingMergeExec {
     expr: Vec<PhysicalSortExpr>,
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
+    /// Optional number of rows to fetch. Stops producing rows after this fetch
+    fetch: Option<usize>,
 }
 
 impl SortPreservingMergeExec {
@@ -80,8 +82,14 @@ impl SortPreservingMergeExec {
             input,
             expr,
             metrics: ExecutionPlanMetricsSet::new(),
+            fetch: None,
         }
     }
+    /// Sets the number of rows to fetch
+    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
+        self.fetch = fetch;
+        self
+    }
 
     /// Input schema
     pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
@@ -92,6 +100,11 @@ impl SortPreservingMergeExec {
     pub fn expr(&self) -> &[PhysicalSortExpr] {
         &self.expr
     }
+
+    /// Fetch
+    pub fn fetch(&self) -> Option<usize> {
+        self.fetch
+    }
 }
 
 impl ExecutionPlan for SortPreservingMergeExec {
@@ -137,10 +150,10 @@ impl ExecutionPlan for SortPreservingMergeExec {
         self: Arc<Self>,
         children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        Ok(Arc::new(SortPreservingMergeExec::new(
-            self.expr.clone(),
-            children[0].clone(),
-        )))
+        Ok(Arc::new(
+            SortPreservingMergeExec::new(self.expr.clone(), children[0].clone())
+                .with_fetch(self.fetch),
+        ))
     }
 
     fn execute(
@@ -192,6 +205,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
                     &self.expr,
                     BaselineMetrics::new(&self.metrics, partition),
                     context.session_config().batch_size(),
+                    self.fetch,
                 )?;
 
                 debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");
@@ -209,7 +223,12 @@ impl ExecutionPlan for SortPreservingMergeExec {
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
                 let expr: Vec<String> = self.expr.iter().map(|e| e.to_string()).collect();
-                write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))
+                write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))?;
+                if let Some(fetch) = self.fetch {
+                    write!(f, ", fetch={fetch}")?;
+                };
+
+                Ok(())
             }
         }
     }
@@ -814,6 +833,7 @@ mod tests {
             sort.as_slice(),
            BaselineMetrics::new(&metrics, 0),
             task_ctx.session_config().batch_size(),
+            None,
         )
         .unwrap();
 
diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs
index 01bdb629ee840..e0130cb09c8c0 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -599,7 +599,7 @@ async fn test_physical_plan_display_indent() {
     let physical_plan = dataframe.create_physical_plan().await.unwrap();
     let expected = vec![
         "GlobalLimitExec: skip=0, fetch=10",
-        "  SortPreservingMergeExec: [the_min@2 DESC]",
+        "  SortPreservingMergeExec: [the_min@2 DESC], fetch=10",
         "    SortExec: fetch=10, expr=[the_min@2 DESC]",
         "      ProjectionExec: expr=[c1@0 as c1, MAX(aggregate_test_100.c12)@1 as MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)@2 as the_min]",
         "        AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]",
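With the accessors added above, a caller assembling a physical plan by hand can attach the limit directly to the merge operator. The sketch below is hedged against the DataFusion version this patch targets: only new, with_fetch and fetch come from the diff itself, while the import paths, the col helper, the column name "a" and the merge_with_limit wrapper are illustrative assumptions that may need adjusting.

    use std::sync::Arc;

    use datafusion::arrow::compute::SortOptions;
    use datafusion::arrow::datatypes::Schema;
    use datafusion::error::Result;
    use datafusion::physical_expr::PhysicalSortExpr;
    use datafusion::physical_plan::expressions::col;
    use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
    use datafusion::physical_plan::ExecutionPlan;

    /// Wrap an input whose partitions are already sorted on column "a" in a merge
    /// that stops after `limit` rows.
    fn merge_with_limit(
        input: Arc<dyn ExecutionPlan>,
        schema: &Schema,
        limit: usize,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let sort_expr = PhysicalSortExpr {
            expr: col("a", schema)?,
            options: SortOptions::default(),
        };
        // Merge the sorted partitions and stop after `limit` rows.
        let merge = SortPreservingMergeExec::new(vec![sort_expr], input).with_fetch(Some(limit));
        // The limit is observable on the operator itself.
        assert_eq!(merge.fetch(), Some(limit));
        Ok(Arc::new(merge))
    }

Because with_new_children now forwards self.fetch (see the hunk above), the limit also survives optimizer passes that rebuild the node, and it shows up in EXPLAIN output as the fetch=... suffix checked by the updated tests below.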
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part
index b74dca0272fab..d46536a253492 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part
@@ -71,7 +71,7 @@ Limit: skip=0, fetch=10
 ------------TableScan: nation projection=[n_nationkey, n_name]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [revenue@2 DESC]
+--SortPreservingMergeExec: [revenue@2 DESC], fetch=10
 ----SortExec: fetch=10, expr=[revenue@2 DESC]
 ------ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@7 as revenue, c_acctbal@2 as c_acctbal, n_name@4 as n_name, c_address@5 as c_address, c_phone@3 as c_phone, c_comment@6 as c_comment]
 --------AggregateExec: mode=FinalPartitioned, gby=[c_custkey@0 as c_custkey, c_name@1 as c_name, c_acctbal@2 as c_acctbal, c_phone@3 as c_phone, n_name@4 as n_name, c_address@5 as c_address, c_comment@6 as c_comment], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part
index 7429a7e216ec0..9118935c4b73e 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part
@@ -75,7 +75,7 @@ Limit: skip=0, fetch=10
 ----------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [value@1 DESC]
+--SortPreservingMergeExec: [value@1 DESC], fetch=10
 ----SortExec: fetch=10, expr=[value@1 DESC]
 ------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value]
 --------NestedLoopJoinExec: join_type=Inner, filter=CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Decimal128(38, 15)) > SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)@1
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part
index 1d35c0db58ca4..bd358962b5e1b 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part
@@ -56,7 +56,7 @@ Limit: skip=0, fetch=10
 ------------------------TableScan: orders projection=[o_orderkey, o_custkey, o_comment], partial_filters=[orders.o_comment NOT LIKE Utf8("%special%requests%")]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC]
+--SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC], fetch=10
 ----SortExec: fetch=10, expr=[custdist@1 DESC,c_count@0 DESC]
 ------ProjectionExec: expr=[c_count@0 as c_count, COUNT(UInt8(1))@1 as custdist]
 --------AggregateExec: mode=FinalPartitioned, gby=[c_count@0 as c_count], aggr=[COUNT(UInt8(1))]
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part
index 02dc4ae5503fa..1f24791a56933 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part
@@ -67,7 +67,7 @@ Limit: skip=0, fetch=10
 ------------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST]
+--SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], fetch=10
 ----SortExec: fetch=10, expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST]
 ------ProjectionExec: expr=[group_alias_0@0 as part.p_brand, group_alias_1@1 as part.p_type, group_alias_2@2 as part.p_size, COUNT(alias1)@3 as supplier_cnt]
 --------AggregateExec: mode=FinalPartitioned, gby=[group_alias_0@0 as group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as group_alias_2], aggr=[COUNT(alias1)]
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part
index e5c7a54ecfe03..3ad63f482a3a8 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part
@@ -101,7 +101,7 @@ Limit: skip=0, fetch=10
 ----------------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST]
+--SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST], fetch=10
 ----SortExec: fetch=10, expr=[s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST]
 ------ProjectionExec: expr=[s_acctbal@5 as s_acctbal, s_name@2 as s_name, n_name@8 as n_name, p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_address@3 as s_address, s_phone@4 as s_phone, s_comment@6 as s_comment]
 --------CoalesceBatchesExec: target_batch_size=8192
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part
index e573d761ede5d..91af8b77996db 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part
@@ -60,7 +60,7 @@ Limit: skip=0, fetch=10
 ----------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate > Date32("9204")]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST]
+--SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST], fetch=10
 ----SortExec: fetch=10, expr=[revenue@1 DESC,o_orderdate@2 ASC NULLS LAST]
 ------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@3 as revenue, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority]
 --------AggregateExec: mode=FinalPartitioned, gby=[l_orderkey@0 as l_orderkey, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part
index e6344a007ec94..2feaef32cf4be 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part
@@ -77,7 +77,7 @@ Limit: skip=0, fetch=10
 --------------TableScan: nation projection=[n_nationkey, n_name]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC]
+--SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC], fetch=10
 ----SortExec: fetch=10, expr=[nation@0 ASC NULLS LAST,o_year@1 DESC]
 ------ProjectionExec: expr=[nation@0 as nation, o_year@1 as o_year, SUM(profit.amount)@2 as sum_profit]
 --------AggregateExec: mode=FinalPartitioned, gby=[nation@0 as nation, o_year@1 as o_year], aggr=[SUM(profit.amount)]
diff --git a/datafusion/core/tests/sqllogictests/test_files/union.slt b/datafusion/core/tests/sqllogictests/test_files/union.slt
index 94c9eef89324c..2b3022ddd1a35 100644
--- a/datafusion/core/tests/sqllogictests/test_files/union.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/union.slt
@@ -308,7 +308,7 @@ Limit: skip=0, fetch=5
 --------TableScan: aggregate_test_100 projection=[c1, c3]
 physical_plan
 GlobalLimitExec: skip=0, fetch=5
---SortPreservingMergeExec: [c9@1 DESC]
+--SortPreservingMergeExec: [c9@1 DESC], fetch=5
 ----UnionExec
 ------SortExec: expr=[c9@1 DESC]
 --------ProjectionExec: expr=[c1@0 as c1, CAST(c9@1 AS Int64) as c9]
diff --git a/datafusion/core/tests/sqllogictests/test_files/window.slt b/datafusion/core/tests/sqllogictests/test_files/window.slt
index 08d1a5616e8a2..d77df127a80a2 100644
--- a/datafusion/core/tests/sqllogictests/test_files/window.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/window.slt
@@ -1792,7 +1792,7 @@ Limit: skip=0, fetch=5
 ------------TableScan: aggregate_test_100 projection=[c2, c3, c9]
 physical_plan
 GlobalLimitExec: skip=0, fetch=5
---SortPreservingMergeExec: [c3@0 ASC NULLS LAST]
+--SortPreservingMergeExec: [c3@0 ASC NULLS LAST], fetch=5
 ----ProjectionExec: expr=[c3@0 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum2]
 ------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted]
 --------SortExec: expr=[c3@0 ASC NULLS LAST,c9@1 DESC]
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index de334dc4a5cc7..0d61cd2b3573a 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1366,6 +1366,8 @@ message SortExecNode {
 message SortPreservingMergeExecNode {
   PhysicalPlanNode input = 1;
   repeated PhysicalExprNode expr = 2;
+  // Maximum number of highest/lowest rows to fetch; negative means no limit
+  int64 fetch = 3;
 }
 
 message CoalesceBatchesExecNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 1cf08be321e1f..831dd49618f7d 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -20269,6 +20269,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
         if !self.expr.is_empty() {
             len += 1;
         }
+        if self.fetch != 0 {
+            len += 1;
+        }
         let mut struct_ser = serializer.serialize_struct("datafusion.SortPreservingMergeExecNode", len)?;
         if let Some(v) = self.input.as_ref() {
             struct_ser.serialize_field("input", v)?;
@@ -20276,6 +20279,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
         if !self.expr.is_empty() {
             struct_ser.serialize_field("expr", &self.expr)?;
         }
+        if self.fetch != 0 {
+            struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?;
+        }
         struct_ser.end()
     }
 }
@@ -20288,12 +20294,14 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
         const FIELDS: &[&str] = &[
             "input",
             "expr",
+            "fetch",
         ];
 
         #[allow(clippy::enum_variant_names)]
         enum GeneratedField {
             Input,
             Expr,
+            Fetch,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
            fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -20317,6 +20325,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
                     match value {
                         "input" => Ok(GeneratedField::Input),
                         "expr" => Ok(GeneratedField::Expr),
+                        "fetch" => Ok(GeneratedField::Fetch),
                         _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                     }
                 }
@@ -20338,6 +20347,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
            {
                 let mut input__ = None;
                 let mut expr__ = None;
+                let mut fetch__ = None;
                 while let Some(k) = map.next_key()? {
                     match k {
                         GeneratedField::Input => {
@@ -20352,11 +20362,20 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
                             }
                             expr__ = Some(map.next_value()?);
                         }
+                        GeneratedField::Fetch => {
+                            if fetch__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("fetch"));
+                            }
+                            fetch__ =
+                                Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+                            ;
+                        }
                     }
                 }
                 Ok(SortPreservingMergeExecNode {
                     input: input__,
                     expr: expr__.unwrap_or_default(),
+                    fetch: fetch__.unwrap_or_default(),
                 })
             }
         }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 5f201b124d1b9..e6c076e7d4538 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1926,6 +1926,9 @@ pub struct SortPreservingMergeExecNode {
     pub input: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
     #[prost(message, repeated, tag = "2")]
     pub expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
+    /// Maximum number of highest/lowest rows to fetch; negative means no limit
+    #[prost(int64, tag = "3")]
+    pub fetch: i64,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 1daa1c2e4b9cb..7bbbe135680bc 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -692,7 +692,14 @@ impl AsExecutionPlan for PhysicalPlanNode {
                         }
                     })
                     .collect::<Result<Vec<_>, _>>()?;
-                Ok(Arc::new(SortPreservingMergeExec::new(exprs, input)))
+                let fetch = if sort.fetch < 0 {
+                    None
+                } else {
+                    Some(sort.fetch as usize)
+                };
+                Ok(Arc::new(
+                    SortPreservingMergeExec::new(exprs, input).with_fetch(fetch),
+                ))
             }
             PhysicalPlanType::Extension(extension) => {
                 let inputs: Vec<Arc<dyn ExecutionPlan>> = extension
@@ -1144,6 +1151,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     Box::new(protobuf::SortPreservingMergeExecNode {
                         input: Some(Box::new(input)),
                         expr,
+                        fetch: exec.fetch().map(|f| f as i64).unwrap_or(-1),
                     }),
                 )),
             })
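The protobuf field is a plain int64, so the absence of a limit is encoded with a -1 sentinel: serialization maps None to -1, and deserialization maps any negative value back to None, while zero and positive values are real limits. The helpers below are a minimal standalone sketch of that round-trip convention; the function names are illustrative and not part of the generated code.

    /// Encode an optional row limit into the protobuf field: -1 means "no limit".
    fn encode_fetch(fetch: Option<usize>) -> i64 {
        fetch.map(|f| f as i64).unwrap_or(-1)
    }

    /// Decode the protobuf field back into an optional row limit.
    fn decode_fetch(fetch: i64) -> Option<usize> {
        if fetch < 0 {
            None
        } else {
            Some(fetch as usize)
        }
    }

    fn main() {
        // The conversion is lossless for every value the plan node can actually carry.
        assert_eq!(decode_fetch(encode_fetch(None)), None);
        assert_eq!(decode_fetch(encode_fetch(Some(0))), Some(0));
        assert_eq!(decode_fetch(encode_fetch(Some(10))), Some(10));
    }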