@@ -55,7 +55,30 @@ use datafusion::arrow::error::ArrowError;
55
55
use datafusion:: execution:: context:: TaskContext ;
56
56
use datafusion:: physical_plan:: repartition:: BatchPartitioner ;
57
57
use datafusion:: physical_plan:: stream:: RecordBatchStreamAdapter ;
58
+ use lazy_static:: lazy_static;
58
59
use log:: { debug, info} ;
60
+ use lru:: LruCache ;
61
+ use parking_lot:: Mutex ;
62
+ use std:: num:: NonZeroUsize ;
63
+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
64
+
65
+ lazy_static ! {
66
+ static ref LIMIT_ACCUMULATORS : Mutex <LruCache <( String , usize ) , Arc <AtomicUsize >>> =
67
+ Mutex :: new( LruCache :: new( NonZeroUsize :: new( 40 ) . unwrap( ) ) ) ;
68
+ }
69
+
70
+ fn get_limit_accumulator ( job_id : & str , stage : usize ) -> Arc < AtomicUsize > {
71
+ let mut guard = LIMIT_ACCUMULATORS . lock ( ) ;
72
+
73
+ if let Some ( accumulator) = guard. get ( & ( job_id. to_owned ( ) , stage) ) {
74
+ accumulator. clone ( )
75
+ } else {
76
+ let accumulator = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
77
+ guard. push ( ( job_id. to_owned ( ) , stage) , accumulator. clone ( ) ) ;
78
+
79
+ accumulator
80
+ }
81
+ }
59
82
60
83
/// ShuffleWriterExec represents a section of a query plan that has consistent partitioning and
61
84
/// can be executed as one unit with each partition being executed in parallel. The output of each
@@ -75,6 +98,8 @@ pub struct ShuffleWriterExec {
75
98
shuffle_output_partitioning : Option < Partitioning > ,
76
99
/// Execution metrics
77
100
metrics : ExecutionPlanMetricsSet ,
101
+ /// Maximum number of rows to return
102
+ limit : Option < usize > ,
78
103
}
79
104
80
105
#[ derive( Debug , Clone ) ]
@@ -121,6 +146,26 @@ impl ShuffleWriterExec {
121
146
work_dir,
122
147
shuffle_output_partitioning,
123
148
metrics : ExecutionPlanMetricsSet :: new ( ) ,
149
+ limit : None ,
150
+ } )
151
+ }
152
+
153
+ pub fn try_new_with_limit (
154
+ job_id : String ,
155
+ stage_id : usize ,
156
+ plan : Arc < dyn ExecutionPlan > ,
157
+ work_dir : String ,
158
+ shuffle_output_partitioning : Option < Partitioning > ,
159
+ limit : Option < usize > ,
160
+ ) -> Result < Self > {
161
+ Ok ( Self {
162
+ job_id,
163
+ stage_id,
164
+ plan,
165
+ work_dir,
166
+ shuffle_output_partitioning,
167
+ metrics : ExecutionPlanMetricsSet :: new ( ) ,
168
+ limit,
124
169
} )
125
170
}
126
171
@@ -139,6 +184,10 @@ impl ShuffleWriterExec {
139
184
self . shuffle_output_partitioning . as_ref ( )
140
185
}
141
186
187
/// Maximum number of rows this writer should produce, if a limit was set.
pub fn limit(&self) -> Option<usize> {
    self.limit
}
190
+
142
191
pub fn execute_shuffle_write (
143
192
& self ,
144
193
input_partition : usize ,
@@ -152,6 +201,10 @@ impl ShuffleWriterExec {
152
201
let output_partitioning = self . shuffle_output_partitioning . clone ( ) ;
153
202
let plan = self . plan . clone ( ) ;
154
203
204
+ let limit_and_accumulator = self
205
+ . limit
206
+ . map ( |l| ( l, get_limit_accumulator ( & self . job_id , self . stage_id ) ) ) ;
207
+
155
208
async move {
156
209
let now = Instant :: now ( ) ;
157
210
let mut stream = plan. execute ( input_partition, context) ?;
@@ -170,6 +223,7 @@ impl ShuffleWriterExec {
170
223
& mut stream,
171
224
path,
172
225
& write_metrics. write_time ,
226
+ limit_and_accumulator,
173
227
)
174
228
. await
175
229
. map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) ) ?;
@@ -211,10 +265,26 @@ impl ShuffleWriterExec {
211
265
write_metrics. repart_time . clone ( ) ,
212
266
) ?;
213
267
214
- while let Some ( result) = stream. next ( ) . await {
268
+ while let Some ( result) = {
269
+ let poll_more = limit_and_accumulator. as_ref ( ) . map_or (
270
+ true ,
271
+ |( limit, accum) | {
272
+ let total_rows = accum. load ( Ordering :: SeqCst ) ;
273
+ total_rows < * limit
274
+ } ,
275
+ ) ;
276
+
277
+ if poll_more {
278
+ stream. next ( ) . await
279
+ } else {
280
+ None
281
+ }
282
+ } {
215
283
let input_batch = result?;
216
284
217
- write_metrics. input_rows . add ( input_batch. num_rows ( ) ) ;
285
+ let num_rows = input_batch. num_rows ( ) ;
286
+
287
+ write_metrics. input_rows . add ( num_rows) ;
218
288
219
289
partitioner. partition (
220
290
input_batch,
@@ -252,6 +322,13 @@ impl ShuffleWriterExec {
252
322
Ok ( ( ) )
253
323
} ,
254
324
) ?;
325
+
326
+ if let Some ( ( limit, accum) ) = limit_and_accumulator. as_ref ( ) {
327
+ let total_rows = accum. fetch_add ( num_rows, Ordering :: SeqCst ) ;
328
+ if total_rows > * limit {
329
+ break ;
330
+ }
331
+ }
255
332
}
256
333
257
334
let mut part_locs = vec ! [ ] ;
@@ -320,12 +397,13 @@ impl ExecutionPlan for ShuffleWriterExec {
320
397
self : Arc < Self > ,
321
398
children : Vec < Arc < dyn ExecutionPlan > > ,
322
399
) -> Result < Arc < dyn ExecutionPlan > > {
323
- Ok ( Arc :: new ( ShuffleWriterExec :: try_new (
400
+ Ok ( Arc :: new ( ShuffleWriterExec :: try_new_with_limit (
324
401
self . job_id . clone ( ) ,
325
402
self . stage_id ,
326
403
children[ 0 ] . clone ( ) ,
327
404
self . work_dir . clone ( ) ,
328
405
self . shuffle_output_partitioning . clone ( ) ,
406
+ self . limit ,
329
407
) ?) )
330
408
}
331
409
0 commit comments