18
18
use chrono:: { TimeZone , Utc } ;
19
19
use datafusion:: common:: tree_node:: { Transformed , TreeNode } ;
20
20
use datafusion:: execution:: runtime_env:: RuntimeEnv ;
21
- use datafusion:: logical_expr:: { AggregateUDF , ScalarUDF } ;
21
+ use datafusion:: logical_expr:: { AggregateUDF , ScalarUDF , WindowUDF } ;
22
22
use datafusion:: physical_plan:: metrics:: {
23
23
Count , Gauge , MetricValue , MetricsSet , Time , Timestamp ,
24
24
} ;
@@ -279,6 +279,7 @@ pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
279
279
runtime : Arc < RuntimeEnv > ,
280
280
scalar_functions : HashMap < String , Arc < ScalarUDF > > ,
281
281
aggregate_functions : HashMap < String , Arc < AggregateUDF > > ,
282
+ window_functions : HashMap < String , Arc < WindowUDF > > ,
282
283
codec : BallistaCodec < T , U > ,
283
284
) -> Result < TaskDefinition , BallistaError > {
284
285
let mut props = HashMap :: new ( ) ;
@@ -289,16 +290,21 @@ pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
289
290
290
291
let mut task_scalar_functions = HashMap :: new ( ) ;
291
292
let mut task_aggregate_functions = HashMap :: new ( ) ;
293
+ let mut task_window_functions = HashMap :: new ( ) ;
292
294
// TODO combine the functions from Executor's functions and TaskDefinition's function resources
293
295
for scalar_func in scalar_functions {
294
296
task_scalar_functions. insert ( scalar_func. 0 , scalar_func. 1 ) ;
295
297
}
296
298
for agg_func in aggregate_functions {
297
299
task_aggregate_functions. insert ( agg_func. 0 , agg_func. 1 ) ;
298
300
}
301
+ for agg_func in window_functions {
302
+ task_window_functions. insert ( agg_func. 0 , agg_func. 1 ) ;
303
+ }
299
304
let function_registry = Arc :: new ( SimpleFunctionRegistry {
300
305
scalar_functions : task_scalar_functions,
301
306
aggregate_functions : task_aggregate_functions,
307
+ window_functions : task_window_functions,
302
308
} ) ;
303
309
304
310
let encoded_plan = task. plan . as_slice ( ) ;
@@ -342,6 +348,7 @@ pub fn get_task_definition_vec<
342
348
runtime : Arc < RuntimeEnv > ,
343
349
scalar_functions : HashMap < String , Arc < ScalarUDF > > ,
344
350
aggregate_functions : HashMap < String , Arc < AggregateUDF > > ,
351
+ window_functions : HashMap < String , Arc < WindowUDF > > ,
345
352
codec : BallistaCodec < T , U > ,
346
353
) -> Result < Vec < TaskDefinition > , BallistaError > {
347
354
let mut props = HashMap :: new ( ) ;
@@ -352,16 +359,21 @@ pub fn get_task_definition_vec<
352
359
353
360
let mut task_scalar_functions = HashMap :: new ( ) ;
354
361
let mut task_aggregate_functions = HashMap :: new ( ) ;
362
+ let mut task_window_functions = HashMap :: new ( ) ;
355
363
// TODO combine the functions from Executor's functions and TaskDefinition's function resources
356
364
for scalar_func in scalar_functions {
357
365
task_scalar_functions. insert ( scalar_func. 0 , scalar_func. 1 ) ;
358
366
}
359
367
for agg_func in aggregate_functions {
360
368
task_aggregate_functions. insert ( agg_func. 0 , agg_func. 1 ) ;
361
369
}
370
+ for agg_func in window_functions {
371
+ task_window_functions. insert ( agg_func. 0 , agg_func. 1 ) ;
372
+ }
362
373
let function_registry = Arc :: new ( SimpleFunctionRegistry {
363
374
scalar_functions : task_scalar_functions,
364
375
aggregate_functions : task_aggregate_functions,
376
+ window_functions : task_window_functions,
365
377
} ) ;
366
378
367
379
let encoded_plan = multi_task. plan . as_slice ( ) ;
0 commit comments