 // under the License.

 use chrono::{TimeZone, Utc};
+use datafusion::common::tree_node::{Transformed, TreeNode};
+use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
 use datafusion::physical_plan::metrics::{
     Count, Gauge, MetricValue, MetricsSet, Time, Timestamp,
 };
-use datafusion::physical_plan::Metric;
+use datafusion::physical_plan::{ExecutionPlan, Metric};
+use datafusion_proto::logical_plan::AsLogicalPlan;
+use datafusion_proto::physical_plan::AsExecutionPlan;
 use std::collections::HashMap;
 use std::convert::TryInto;
 use std::sync::Arc;
@@ -28,10 +33,10 @@ use std::time::Duration;
 use crate::error::BallistaError;
 use crate::serde::scheduler::{
     Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
-    PartitionLocation, PartitionStats, TaskDefinition,
+    PartitionLocation, PartitionStats, SimpleFunctionRegistry, TaskDefinition,
 };

-use crate::serde::protobuf;
+use crate::serde::{protobuf, BallistaCodec};
 use protobuf::{operator_metric, NamedCount, NamedGauge, NamedTime};

 impl TryInto<Action> for protobuf::Action {
@@ -269,67 +274,138 @@ impl Into<ExecutorData> for protobuf::ExecutorData {
     }
 }

-impl TryInto<(TaskDefinition, Vec<u8>)> for protobuf::TaskDefinition {
-    type Error = BallistaError;
-
-    fn try_into(self) -> Result<(TaskDefinition, Vec<u8>), Self::Error> {
-        let mut props = HashMap::new();
-        for kv_pair in self.props {
-            props.insert(kv_pair.key, kv_pair.value);
-        }
+pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
+    task: protobuf::TaskDefinition,
+    runtime: Arc<RuntimeEnv>,
+    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+    codec: BallistaCodec<T, U>,
+) -> Result<TaskDefinition, BallistaError> {
+    let mut props = HashMap::new();
+    for kv_pair in task.props {
+        props.insert(kv_pair.key, kv_pair.value);
+    }
+    let props = Arc::new(props);

-        Ok((
-            TaskDefinition {
-                task_id: self.task_id as usize,
-                task_attempt_num: self.task_attempt_num as usize,
-                job_id: self.job_id,
-                stage_id: self.stage_id as usize,
-                stage_attempt_num: self.stage_attempt_num as usize,
-                partition_id: self.partition_id as usize,
-                plan: vec![],
-                session_id: self.session_id,
-                launch_time: self.launch_time,
-                props,
-            },
-            self.plan,
-        ))
+    let mut task_scalar_functions = HashMap::new();
+    let mut task_aggregate_functions = HashMap::new();
+    // TODO combine the functions from Executor's functions and TaskDefinition's function resources
+    for scalar_func in scalar_functions {
+        task_scalar_functions.insert(scalar_func.0, scalar_func.1);
     }
-}
+    for agg_func in aggregate_functions {
+        task_aggregate_functions.insert(agg_func.0, agg_func.1);
+    }
+    let function_registry = Arc::new(SimpleFunctionRegistry {
+        scalar_functions: task_scalar_functions,
+        aggregate_functions: task_aggregate_functions,
+    });

-impl TryInto<(Vec<TaskDefinition>, Vec<u8>)> for protobuf::MultiTaskDefinition {
-    type Error = BallistaError;
+    let encoded_plan = task.plan.as_slice();
+    let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
+        proto.try_into_physical_plan(
+            function_registry.as_ref(),
+            runtime.as_ref(),
+            codec.physical_extension_codec(),
+        )
+    })?;

-    fn try_into(self) -> Result<(Vec<TaskDefinition>, Vec<u8>), Self::Error> {
-        let mut props = HashMap::new();
-        for kv_pair in self.props {
-            props.insert(kv_pair.key, kv_pair.value);
-        }
+    let job_id = task.job_id;
+    let stage_id = task.stage_id as usize;
+    let partition_id = task.partition_id as usize;
+    let task_attempt_num = task.task_attempt_num as usize;
+    let stage_attempt_num = task.stage_attempt_num as usize;
+    let launch_time = task.launch_time;
+    let task_id = task.task_id as usize;
+    let session_id = task.session_id;

-        let plan = self.plan;
-        let session_id = self.session_id;
-        let job_id = self.job_id;
-        let stage_id = self.stage_id as usize;
-        let stage_attempt_num = self.stage_attempt_num as usize;
-        let launch_time = self.launch_time;
-        let task_ids = self.task_ids;
+    Ok(TaskDefinition {
+        task_id,
+        task_attempt_num,
+        job_id,
+        stage_id,
+        stage_attempt_num,
+        partition_id,
+        plan,
+        launch_time,
+        session_id,
+        props,
+        function_registry,
+    })
+}

-        Ok((
-            task_ids
-                .iter()
-                .map(|task_id| TaskDefinition {
-                    task_id: task_id.task_id as usize,
-                    task_attempt_num: task_id.task_attempt_num as usize,
-                    job_id: job_id.clone(),
-                    stage_id,
-                    stage_attempt_num,
-                    partition_id: task_id.partition_id as usize,
-                    plan: vec![],
-                    session_id: session_id.clone(),
-                    launch_time,
-                    props: props.clone(),
-                })
-                .collect(),
-            plan,
-        ))
+pub fn get_task_definition_vec<
+    T: 'static + AsLogicalPlan,
+    U: 'static + AsExecutionPlan,
+>(
+    multi_task: protobuf::MultiTaskDefinition,
+    runtime: Arc<RuntimeEnv>,
+    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+    codec: BallistaCodec<T, U>,
+) -> Result<Vec<TaskDefinition>, BallistaError> {
+    let mut props = HashMap::new();
+    for kv_pair in multi_task.props {
+        props.insert(kv_pair.key, kv_pair.value);
     }
+    let props = Arc::new(props);
+
+    let mut task_scalar_functions = HashMap::new();
+    let mut task_aggregate_functions = HashMap::new();
+    // TODO combine the functions from Executor's functions and TaskDefinition's function resources
+    for scalar_func in scalar_functions {
+        task_scalar_functions.insert(scalar_func.0, scalar_func.1);
+    }
+    for agg_func in aggregate_functions {
+        task_aggregate_functions.insert(agg_func.0, agg_func.1);
+    }
+    let function_registry = Arc::new(SimpleFunctionRegistry {
+        scalar_functions: task_scalar_functions,
+        aggregate_functions: task_aggregate_functions,
+    });
+
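+    // Decode the shared physical plan once for the whole multi-task definition;
+    // each task below then gets its own copy with metrics reset.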
+    let encoded_plan = multi_task.plan.as_slice();
+    let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
+        proto.try_into_physical_plan(
+            function_registry.as_ref(),
+            runtime.as_ref(),
+            codec.physical_extension_codec(),
+        )
+    })?;
+
+    let job_id = multi_task.job_id;
+    let stage_id = multi_task.stage_id as usize;
+    let stage_attempt_num = multi_task.stage_attempt_num as usize;
+    let launch_time = multi_task.launch_time;
+    let task_ids = multi_task.task_ids;
+    let session_id = multi_task.session_id;
+
+    task_ids
+        .iter()
+        .map(|task_id| {
+            Ok(TaskDefinition {
+                task_id: task_id.task_id as usize,
+                task_attempt_num: task_id.task_attempt_num as usize,
+                job_id: job_id.clone(),
+                stage_id,
+                stage_attempt_num,
+                partition_id: task_id.partition_id as usize,
+                plan: reset_metrics_for_execution_plan(plan.clone())?,
+                launch_time,
+                session_id: session_id.clone(),
+                props: props.clone(),
+                function_registry: function_registry.clone(),
+            })
+        })
+        .collect()
+}
+
+fn reset_metrics_for_execution_plan(
+    plan: Arc<dyn ExecutionPlan>,
+) -> Result<Arc<dyn ExecutionPlan>, BallistaError> {
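+    // `with_new_children` rebuilds every operator, so the returned plan consists of
+    // fresh nodes whose metrics start out empty instead of being shared across tasks.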
+    plan.transform(&|plan| {
+        let children = plan.children().clone();
+        plan.with_new_children(children).map(Transformed::Yes)
+    })
+    .map_err(BallistaError::DataFusionError)
 }
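
Side note on the `// TODO combine the functions from Executor's functions and TaskDefinition's function resources` comment that appears in both new functions: below is a minimal sketch (not part of this change) of how that merge could look, assuming `SimpleFunctionRegistry` keeps the two maps exactly as constructed above. The helper name `merge_function_registry` and its task-scoped parameters are hypothetical and used only for illustration.

```rust
use std::collections::HashMap;
use std::sync::Arc;

use datafusion::logical_expr::{AggregateUDF, ScalarUDF};

use crate::serde::scheduler::SimpleFunctionRegistry;

/// Hypothetical helper for the TODO above: merge the executor-wide UDFs with any
/// task-scoped UDFs, letting the task-scoped definitions win on name collisions.
fn merge_function_registry(
    executor_scalar: HashMap<String, Arc<ScalarUDF>>,
    executor_aggregate: HashMap<String, Arc<AggregateUDF>>,
    task_scalar: HashMap<String, Arc<ScalarUDF>>,
    task_aggregate: HashMap<String, Arc<AggregateUDF>>,
) -> Arc<SimpleFunctionRegistry> {
    let mut scalar_functions = executor_scalar;
    // `extend` overwrites entries with the same key, so task-provided functions
    // take precedence over the executor defaults.
    scalar_functions.extend(task_scalar);

    let mut aggregate_functions = executor_aggregate;
    aggregate_functions.extend(task_aggregate);

    Arc::new(SimpleFunctionRegistry {
        scalar_functions,
        aggregate_functions,
    })
}
```

With something along these lines, both `get_task_definition` and `get_task_definition_vec` could build `function_registry` from a single call instead of copying the executor maps entry by entry.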