
Commit d7a808c

Refactor the TaskDefinition by changing the encoded execution plan to the decoded one (#817)

* Revert "Only decode plan in `LaunchMultiTaskParams` once (#743)". This reverts commit 4e4842c.
* Refactor the TaskDefinition by changing the encoded execution plan to the decoded one.
* Refine the error handling of run_task in the executor_server.

Co-authored-by: yangzhong <[email protected]>
1 parent 553b9a7 commit d7a808c
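In short (an editor's summary assembled from the diffs below, not text from the commit): `TaskDefinition` used to carry the physical plan as encoded protobuf bytes that every consumer re-decoded; it now carries the plan decoded once, plus the function registry needed to decode it.

```rust
// Before/after of the TaskDefinition fields, condensed from the diffs below.
//
// Before: the plan shipped as raw bytes, decoded by each consumer.
//     pub plan: Vec<u8>,
//     pub props: HashMap<String, String>,
//
// After: the plan is decoded once and shared; props are Arc-shared too.
//     pub plan: Arc<dyn ExecutionPlan>,
//     pub props: Arc<HashMap<String, String>>,
//     pub function_registry: Arc<SimpleFunctionRegistry>,
```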

File tree: 5 files changed, +272 −280 lines


ballista/core/src/serde/scheduler/from_proto.rs

+135 −59
```diff
@@ -16,10 +16,15 @@
 // under the License.
 
 use chrono::{TimeZone, Utc};
+use datafusion::common::tree_node::{Transformed, TreeNode};
+use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
 use datafusion::physical_plan::metrics::{
     Count, Gauge, MetricValue, MetricsSet, Time, Timestamp,
 };
-use datafusion::physical_plan::Metric;
+use datafusion::physical_plan::{ExecutionPlan, Metric};
+use datafusion_proto::logical_plan::AsLogicalPlan;
+use datafusion_proto::physical_plan::AsExecutionPlan;
 use std::collections::HashMap;
 use std::convert::TryInto;
 use std::sync::Arc;
@@ -28,10 +33,10 @@ use std::time::Duration;
 use crate::error::BallistaError;
 use crate::serde::scheduler::{
     Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
-    PartitionLocation, PartitionStats, TaskDefinition,
+    PartitionLocation, PartitionStats, SimpleFunctionRegistry, TaskDefinition,
 };
 
-use crate::serde::protobuf;
+use crate::serde::{protobuf, BallistaCodec};
 use protobuf::{operator_metric, NamedCount, NamedGauge, NamedTime};
 
 impl TryInto<Action> for protobuf::Action {
@@ -269,67 +274,138 @@ impl Into<ExecutorData> for protobuf::ExecutorData {
     }
 }
 
-impl TryInto<(TaskDefinition, Vec<u8>)> for protobuf::TaskDefinition {
-    type Error = BallistaError;
-
-    fn try_into(self) -> Result<(TaskDefinition, Vec<u8>), Self::Error> {
-        let mut props = HashMap::new();
-        for kv_pair in self.props {
-            props.insert(kv_pair.key, kv_pair.value);
-        }
+pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
+    task: protobuf::TaskDefinition,
+    runtime: Arc<RuntimeEnv>,
+    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+    codec: BallistaCodec<T, U>,
+) -> Result<TaskDefinition, BallistaError> {
+    let mut props = HashMap::new();
+    for kv_pair in task.props {
+        props.insert(kv_pair.key, kv_pair.value);
+    }
+    let props = Arc::new(props);
 
-        Ok((
-            TaskDefinition {
-                task_id: self.task_id as usize,
-                task_attempt_num: self.task_attempt_num as usize,
-                job_id: self.job_id,
-                stage_id: self.stage_id as usize,
-                stage_attempt_num: self.stage_attempt_num as usize,
-                partition_id: self.partition_id as usize,
-                plan: vec![],
-                session_id: self.session_id,
-                launch_time: self.launch_time,
-                props,
-            },
-            self.plan,
-        ))
+    let mut task_scalar_functions = HashMap::new();
+    let mut task_aggregate_functions = HashMap::new();
+    // TODO combine the functions from Executor's functions and TaskDefinition's function resources
+    for scalar_func in scalar_functions {
+        task_scalar_functions.insert(scalar_func.0, scalar_func.1);
     }
-}
+    for agg_func in aggregate_functions {
+        task_aggregate_functions.insert(agg_func.0, agg_func.1);
+    }
+    let function_registry = Arc::new(SimpleFunctionRegistry {
+        scalar_functions: task_scalar_functions,
+        aggregate_functions: task_aggregate_functions,
+    });
 
-impl TryInto<(Vec<TaskDefinition>, Vec<u8>)> for protobuf::MultiTaskDefinition {
-    type Error = BallistaError;
+    let encoded_plan = task.plan.as_slice();
+    let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
+        proto.try_into_physical_plan(
+            function_registry.as_ref(),
+            runtime.as_ref(),
+            codec.physical_extension_codec(),
+        )
+    })?;
 
-    fn try_into(self) -> Result<(Vec<TaskDefinition>, Vec<u8>), Self::Error> {
-        let mut props = HashMap::new();
-        for kv_pair in self.props {
-            props.insert(kv_pair.key, kv_pair.value);
-        }
+    let job_id = task.job_id;
+    let stage_id = task.stage_id as usize;
+    let partition_id = task.partition_id as usize;
+    let task_attempt_num = task.task_attempt_num as usize;
+    let stage_attempt_num = task.stage_attempt_num as usize;
+    let launch_time = task.launch_time;
+    let task_id = task.task_id as usize;
+    let session_id = task.session_id;
 
-        let plan = self.plan;
-        let session_id = self.session_id;
-        let job_id = self.job_id;
-        let stage_id = self.stage_id as usize;
-        let stage_attempt_num = self.stage_attempt_num as usize;
-        let launch_time = self.launch_time;
-        let task_ids = self.task_ids;
+    Ok(TaskDefinition {
+        task_id,
+        task_attempt_num,
+        job_id,
+        stage_id,
+        stage_attempt_num,
+        partition_id,
+        plan,
+        launch_time,
+        session_id,
+        props,
+        function_registry,
+    })
+}
 
-        Ok((
-            task_ids
-                .iter()
-                .map(|task_id| TaskDefinition {
-                    task_id: task_id.task_id as usize,
-                    task_attempt_num: task_id.task_attempt_num as usize,
-                    job_id: job_id.clone(),
-                    stage_id,
-                    stage_attempt_num,
-                    partition_id: task_id.partition_id as usize,
-                    plan: vec![],
-                    session_id: session_id.clone(),
-                    launch_time,
-                    props: props.clone(),
-                })
-                .collect(),
-            plan,
-        ))
+pub fn get_task_definition_vec<
+    T: 'static + AsLogicalPlan,
+    U: 'static + AsExecutionPlan,
+>(
+    multi_task: protobuf::MultiTaskDefinition,
+    runtime: Arc<RuntimeEnv>,
+    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+    codec: BallistaCodec<T, U>,
+) -> Result<Vec<TaskDefinition>, BallistaError> {
+    let mut props = HashMap::new();
+    for kv_pair in multi_task.props {
+        props.insert(kv_pair.key, kv_pair.value);
     }
+    let props = Arc::new(props);
+
+    let mut task_scalar_functions = HashMap::new();
+    let mut task_aggregate_functions = HashMap::new();
+    // TODO combine the functions from Executor's functions and TaskDefinition's function resources
+    for scalar_func in scalar_functions {
+        task_scalar_functions.insert(scalar_func.0, scalar_func.1);
+    }
+    for agg_func in aggregate_functions {
+        task_aggregate_functions.insert(agg_func.0, agg_func.1);
+    }
+    let function_registry = Arc::new(SimpleFunctionRegistry {
+        scalar_functions: task_scalar_functions,
+        aggregate_functions: task_aggregate_functions,
+    });
+
+    let encoded_plan = multi_task.plan.as_slice();
+    let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
+        proto.try_into_physical_plan(
+            function_registry.as_ref(),
+            runtime.as_ref(),
+            codec.physical_extension_codec(),
+        )
+    })?;
+
+    let job_id = multi_task.job_id;
+    let stage_id = multi_task.stage_id as usize;
+    let stage_attempt_num = multi_task.stage_attempt_num as usize;
+    let launch_time = multi_task.launch_time;
+    let task_ids = multi_task.task_ids;
+    let session_id = multi_task.session_id;
+
+    task_ids
+        .iter()
+        .map(|task_id| {
+            Ok(TaskDefinition {
+                task_id: task_id.task_id as usize,
+                task_attempt_num: task_id.task_attempt_num as usize,
+                job_id: job_id.clone(),
+                stage_id,
+                stage_attempt_num,
+                partition_id: task_id.partition_id as usize,
+                plan: reset_metrics_for_execution_plan(plan.clone())?,
+                launch_time,
+                session_id: session_id.clone(),
+                props: props.clone(),
+                function_registry: function_registry.clone(),
+            })
+        })
        .collect()
+}
+
+fn reset_metrics_for_execution_plan(
+    plan: Arc<dyn ExecutionPlan>,
+) -> Result<Arc<dyn ExecutionPlan>, BallistaError> {
+    plan.transform(&|plan| {
+        let children = plan.children().clone();
+        plan.with_new_children(children).map(Transformed::Yes)
+    })
+    .map_err(BallistaError::DataFusionError)
 }
```
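For orientation, here is a minimal sketch of how an executor-side caller might use the new helper. The wiring below is my assumption rather than code from this commit: the exact module path, the `BallistaCodec::default()` setup, the `datafusion_proto` plan-node types, and the empty UDF maps. The real call site lives in the executor server (part of the fifth changed file, not shown in this section).

```rust
// Hypothetical call site for get_task_definition_vec; see the caveats above.
use std::collections::HashMap;
use std::sync::Arc;

use ballista_core::error::BallistaError;
use ballista_core::serde::scheduler::from_proto::get_task_definition_vec;
use ballista_core::serde::scheduler::TaskDefinition;
use ballista_core::serde::{protobuf, BallistaCodec};
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};

fn decode_tasks(
    multi_task: protobuf::MultiTaskDefinition,
    runtime: Arc<RuntimeEnv>,
) -> Result<Vec<TaskDefinition>, BallistaError> {
    // The plan bytes are decoded exactly once inside the helper; each
    // returned TaskDefinition holds a rebuilt clone of the decoded plan
    // rather than the raw protobuf bytes.
    get_task_definition_vec::<LogicalPlanNode, PhysicalPlanNode>(
        multi_task,
        runtime,
        HashMap::new(), // scalar UDFs registered on this executor (none here)
        HashMap::new(), // aggregate UDFs registered on this executor (none here)
        BallistaCodec::default(),
    )
}
```

Note that each task receives `reset_metrics_for_execution_plan(plan.clone())` rather than a plain `plan.clone()`: cloning an `Arc` would make every task share the same operator instances, while rebuilding each node through `with_new_children` yields fresh instances, so the tasks of a stage do not share one set of metrics.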

ballista/core/src/serde/scheduler/mod.rs

+42 −4
```diff
@@ -15,12 +15,17 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::collections::HashSet;
+use std::fmt::Debug;
 use std::{collections::HashMap, fmt, sync::Arc};
 
 use datafusion::arrow::array::{
     ArrayBuilder, StructArray, StructBuilder, UInt64Array, UInt64Builder,
 };
 use datafusion::arrow::datatypes::{DataType, Field};
+use datafusion::common::DataFusionError;
+use datafusion::execution::FunctionRegistry;
+use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::physical_plan::Partitioning;
 use serde::Serialize;
@@ -271,16 +276,49 @@ impl ExecutePartitionResult {
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Clone, Debug)]
 pub struct TaskDefinition {
     pub task_id: usize,
     pub task_attempt_num: usize,
     pub job_id: String,
     pub stage_id: usize,
     pub stage_attempt_num: usize,
     pub partition_id: usize,
-    pub plan: Vec<u8>,
-    pub session_id: String,
+    pub plan: Arc<dyn ExecutionPlan>,
     pub launch_time: u64,
-    pub props: HashMap<String, String>,
+    pub session_id: String,
+    pub props: Arc<HashMap<String, String>>,
+    pub function_registry: Arc<SimpleFunctionRegistry>,
+}
+
+#[derive(Debug)]
+pub struct SimpleFunctionRegistry {
+    pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+}
+
+impl FunctionRegistry for SimpleFunctionRegistry {
+    fn udfs(&self) -> HashSet<String> {
+        self.scalar_functions.keys().cloned().collect()
+    }
+
+    fn udf(&self, name: &str) -> datafusion::common::Result<Arc<ScalarUDF>> {
+        let result = self.scalar_functions.get(name);
+
+        result.cloned().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "There is no UDF named \"{name}\" in the TaskContext"
+            ))
+        })
+    }
+
+    fn udaf(&self, name: &str) -> datafusion::common::Result<Arc<AggregateUDF>> {
+        let result = self.aggregate_functions.get(name);
+
+        result.cloned().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "There is no UDAF named \"{name}\" in the TaskContext"
+            ))
+        })
+    }
 }
```
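A short, self-contained sketch of how the new registry behaves through DataFusion's `FunctionRegistry` trait; the empty maps and the `my_udf`/`my_udaf` names are illustrative only:

```rust
use std::collections::HashMap;

use ballista_core::serde::scheduler::SimpleFunctionRegistry;
use datafusion::execution::FunctionRegistry;

fn demo() {
    // An empty registry, like the one get_task_definition builds when the
    // executor registers no UDFs.
    let registry = SimpleFunctionRegistry {
        scalar_functions: HashMap::new(),
        aggregate_functions: HashMap::new(),
    };

    // udfs() lists the registered scalar function names (none here).
    assert!(registry.udfs().is_empty());

    // Unknown names surface as DataFusionError::Internal, so decoding a plan
    // that references a missing UDF fails fast with a descriptive message.
    assert!(registry.udf("my_udf").is_err());
    assert!(registry.udaf("my_udaf").is_err());
}
```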

ballista/core/src/serde/scheduler/to_proto.rs

+2 −31
```diff
@@ -26,12 +26,10 @@ use datafusion_proto::protobuf as datafusion_protobuf;
 
 use crate::serde::scheduler::{
     Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
-    PartitionLocation, PartitionStats, TaskDefinition,
+    PartitionLocation, PartitionStats,
 };
 use datafusion::physical_plan::Partitioning;
-use protobuf::{
-    action::ActionType, operator_metric, KeyValuePair, NamedCount, NamedGauge, NamedTime,
-};
+use protobuf::{action::ActionType, operator_metric, NamedCount, NamedGauge, NamedTime};
 
 impl TryInto<protobuf::Action> for Action {
     type Error = BallistaError;
@@ -242,30 +240,3 @@ impl Into<protobuf::ExecutorData> for ExecutorData {
         }
     }
 }
-
-#[allow(clippy::from_over_into)]
-impl Into<protobuf::TaskDefinition> for TaskDefinition {
-    fn into(self) -> protobuf::TaskDefinition {
-        let props = self
-            .props
-            .iter()
-            .map(|(k, v)| KeyValuePair {
-                key: k.to_owned(),
-                value: v.to_owned(),
-            })
-            .collect::<Vec<_>>();
-
-        protobuf::TaskDefinition {
-            task_id: self.task_id as u32,
-            task_attempt_num: self.task_attempt_num as u32,
-            job_id: self.job_id,
-            stage_id: self.stage_id as u32,
-            stage_attempt_num: self.stage_attempt_num as u32,
-            partition_id: self.partition_id as u32,
-            plan: self.plan,
-            session_id: self.session_id,
-            launch_time: self.launch_time,
-            props,
-        }
-    }
-}
```

ballista/executor/src/execution_engine.rs

+0 −7
```diff
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::SchemaRef;
 use async_trait::async_trait;
 use ballista_core::execution_plans::ShuffleWriterExec;
 use ballista_core::serde::protobuf::ShuffleWritePartition;
@@ -52,8 +51,6 @@ pub trait QueryStageExecutor: Sync + Send + Debug {
     ) -> Result<Vec<ShuffleWritePartition>>;
 
     fn collect_plan_metrics(&self) -> Vec<MetricsSet>;
-
-    fn schema(&self) -> SchemaRef;
 }
 
 pub struct DefaultExecutionEngine {}
@@ -111,10 +108,6 @@ impl QueryStageExecutor for DefaultQueryStageExec {
             .await
     }
 
-    fn schema(&self) -> SchemaRef {
-        self.shuffle_writer.schema()
-    }
-
     fn collect_plan_metrics(&self) -> Vec<MetricsSet> {
         utils::collect_plan_metrics(&self.shuffle_writer)
     }
```
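The trait can shrink because of the refactor above: my reading (not stated in the commit) is that once `TaskDefinition.plan` is a decoded `Arc<dyn ExecutionPlan>`, callers that previously asked the query-stage executor for the schema can ask the plan directly.

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::SchemaRef;
use datafusion::physical_plan::ExecutionPlan;

// With the plan decoded up front, the schema is one method call away
// wherever the TaskDefinition is handled; no trait method required.
fn task_schema(plan: &Arc<dyn ExecutionPlan>) -> SchemaRef {
    plan.schema()
}
```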
