From f8d48906f2b37479e8922c912eb3e62a8d6f9305 Mon Sep 17 00:00:00 2001
From: yangzhong <yangzhong@ebay.com>
Date: Tue, 30 May 2023 15:39:14 +0800
Subject: [PATCH 1/3] Revert "Only decode plan in `LaunchMultiTaskParams`  once
 (#743)"

This reverts commit 4e4842ce5221b8ce6ce39b82bb1346e337129b0d.
---
 .../core/src/serde/scheduler/from_proto.rs    |  68 +++---
 ballista/executor/src/execution_engine.rs     |   7 -
 ballista/executor/src/executor_server.rs      | 199 ++++++------------
 3 files changed, 90 insertions(+), 184 deletions(-)

diff --git a/ballista/core/src/serde/scheduler/from_proto.rs b/ballista/core/src/serde/scheduler/from_proto.rs
index 17875e2b1..d93852118 100644
--- a/ballista/core/src/serde/scheduler/from_proto.rs
+++ b/ballista/core/src/serde/scheduler/from_proto.rs
@@ -269,37 +269,34 @@ impl Into<ExecutorData> for protobuf::ExecutorData {
     }
 }
 
-impl TryInto<(TaskDefinition, Vec<u8>)> for protobuf::TaskDefinition {
+impl TryInto<TaskDefinition> for protobuf::TaskDefinition {
     type Error = BallistaError;
 
-    fn try_into(self) -> Result<(TaskDefinition, Vec<u8>), Self::Error> {
+    fn try_into(self) -> Result<TaskDefinition, Self::Error> {
         let mut props = HashMap::new();
         for kv_pair in self.props {
             props.insert(kv_pair.key, kv_pair.value);
         }
 
-        Ok((
-            TaskDefinition {
-                task_id: self.task_id as usize,
-                task_attempt_num: self.task_attempt_num as usize,
-                job_id: self.job_id,
-                stage_id: self.stage_id as usize,
-                stage_attempt_num: self.stage_attempt_num as usize,
-                partition_id: self.partition_id as usize,
-                plan: vec![],
-                session_id: self.session_id,
-                launch_time: self.launch_time,
-                props,
-            },
-            self.plan,
-        ))
+        Ok(TaskDefinition {
+            task_id: self.task_id as usize,
+            task_attempt_num: self.task_attempt_num as usize,
+            job_id: self.job_id,
+            stage_id: self.stage_id as usize,
+            stage_attempt_num: self.stage_attempt_num as usize,
+            partition_id: self.partition_id as usize,
+            plan: self.plan,
+            session_id: self.session_id,
+            launch_time: self.launch_time,
+            props,
+        })
     }
 }
 
-impl TryInto<(Vec<TaskDefinition>, Vec<u8>)> for protobuf::MultiTaskDefinition {
+impl TryInto<Vec<TaskDefinition>> for protobuf::MultiTaskDefinition {
     type Error = BallistaError;
 
-    fn try_into(self) -> Result<(Vec<TaskDefinition>, Vec<u8>), Self::Error> {
+    fn try_into(self) -> Result<Vec<TaskDefinition>, Self::Error> {
         let mut props = HashMap::new();
         for kv_pair in self.props {
             props.insert(kv_pair.key, kv_pair.value);
@@ -313,23 +310,20 @@ impl TryInto<(Vec<TaskDefinition>, Vec<u8>)> for protobuf::MultiTaskDefinition {
         let launch_time = self.launch_time;
         let task_ids = self.task_ids;
 
-        Ok((
-            task_ids
-                .iter()
-                .map(|task_id| TaskDefinition {
-                    task_id: task_id.task_id as usize,
-                    task_attempt_num: task_id.task_attempt_num as usize,
-                    job_id: job_id.clone(),
-                    stage_id,
-                    stage_attempt_num,
-                    partition_id: task_id.partition_id as usize,
-                    plan: vec![],
-                    session_id: session_id.clone(),
-                    launch_time,
-                    props: props.clone(),
-                })
-                .collect(),
-            plan,
-        ))
+        Ok(task_ids
+            .iter()
+            .map(|task_id| TaskDefinition {
+                task_id: task_id.task_id as usize,
+                task_attempt_num: task_id.task_attempt_num as usize,
+                job_id: job_id.clone(),
+                stage_id,
+                stage_attempt_num,
+                partition_id: task_id.partition_id as usize,
+                plan: plan.clone(),
+                session_id: session_id.clone(),
+                launch_time,
+                props: props.clone(),
+            })
+            .collect())
     }
 }
diff --git a/ballista/executor/src/execution_engine.rs b/ballista/executor/src/execution_engine.rs
index 965153298..5121f016b 100644
--- a/ballista/executor/src/execution_engine.rs
+++ b/ballista/executor/src/execution_engine.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::SchemaRef;
 use async_trait::async_trait;
 use ballista_core::execution_plans::ShuffleWriterExec;
 use ballista_core::serde::protobuf::ShuffleWritePartition;
@@ -52,8 +51,6 @@ pub trait QueryStageExecutor: Sync + Send + Debug {
     ) -> Result<Vec<ShuffleWritePartition>>;
 
     fn collect_plan_metrics(&self) -> Vec<MetricsSet>;
-
-    fn schema(&self) -> SchemaRef;
 }
 
 pub struct DefaultExecutionEngine {}
@@ -111,10 +108,6 @@ impl QueryStageExecutor for DefaultQueryStageExec {
             .await
     }
 
-    fn schema(&self) -> SchemaRef {
-        self.shuffle_writer.schema()
-    }
-
     fn collect_plan_metrics(&self) -> Vec<MetricsSet> {
         utils::collect_plan_metrics(&self.shuffle_writer)
     }
diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs
index 9102923f6..3b47a23ab 100644
--- a/ballista/executor/src/executor_server.rs
+++ b/ballista/executor/src/executor_server.rs
@@ -47,12 +47,12 @@ use ballista_core::serde::BallistaCodec;
 use ballista_core::utils::{create_grpc_client_connection, create_grpc_server};
 use dashmap::DashMap;
 use datafusion::execution::context::TaskContext;
+use datafusion::physical_plan::ExecutionPlan;
 use datafusion_proto::{logical_plan::AsLogicalPlan, physical_plan::AsExecutionPlan};
 use tokio::sync::mpsc::error::TryRecvError;
 use tokio::task::JoinHandle;
 
 use crate::cpu_bound_executor::DedicatedExecutor;
-use crate::execution_engine::QueryStageExecutor;
 use crate::executor::Executor;
 use crate::executor_process::ExecutorProcessConfig;
 use crate::shutdown::ShutdownNotifier;
@@ -65,8 +65,7 @@ type SchedulerClients = Arc<DashMap<String, SchedulerGrpcClient<Channel>>>;
 #[derive(Debug)]
 struct CuratorTaskDefinition {
     scheduler_id: String,
-    plan: Vec<u8>,
-    tasks: Vec<TaskDefinition>,
+    task: TaskDefinition,
 }
 
 /// Wrap TaskStatus with its curator scheduler id for task update to its specific curator scheduler later
@@ -298,67 +297,17 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
         }
     }
 
-    async fn decode_task(
-        &self,
-        curator_task: TaskDefinition,
-        plan: &[u8],
-    ) -> Result<Arc<dyn QueryStageExecutor>, BallistaError> {
-        let task = curator_task;
-        let task_identity = task_identity(&task);
-        let task_props = task.props;
-        let mut config = ConfigOptions::new();
-        for (k, v) in task_props {
-            config.set(&k, &v)?;
-        }
-        let session_config = SessionConfig::from(config);
-
-        let mut task_scalar_functions = HashMap::new();
-        let mut task_aggregate_functions = HashMap::new();
-        for scalar_func in self.executor.scalar_functions.clone() {
-            task_scalar_functions.insert(scalar_func.0, scalar_func.1);
-        }
-        for agg_func in self.executor.aggregate_functions.clone() {
-            task_aggregate_functions.insert(agg_func.0, agg_func.1);
-        }
-
-        let task_context = Arc::new(TaskContext::new(
-            Some(task_identity),
-            task.session_id.clone(),
-            session_config,
-            task_scalar_functions,
-            task_aggregate_functions,
-            self.executor.runtime.clone(),
-        ));
-
-        let plan = U::try_decode(plan).and_then(|proto| {
-            proto.try_into_physical_plan(
-                task_context.deref(),
-                &self.executor.runtime,
-                self.codec.physical_extension_codec(),
-            )
-        })?;
-
-        Ok(self.executor.execution_engine.create_query_stage_exec(
-            task.job_id,
-            task.stage_id,
-            plan,
-            &self.executor.work_dir,
-        )?)
-    }
-
     async fn run_task(
         &self,
-        task_identity: &str,
-        scheduler_id: String,
-        curator_task: TaskDefinition,
-        query_stage_exec: Arc<dyn QueryStageExecutor>,
+        task_identity: String,
+        curator_task: CuratorTaskDefinition,
     ) -> Result<(), BallistaError> {
         let start_exec_time = SystemTime::now()
             .duration_since(UNIX_EPOCH)
             .unwrap()
             .as_millis() as u64;
         info!("Start to run task {}", task_identity);
-        let task = curator_task;
+        let task = curator_task.task;
         let task_props = task.props;
         let mut config = ConfigOptions::new();
         for (k, v) in task_props {
@@ -379,7 +328,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
         let session_id = task.session_id;
         let runtime = self.executor.runtime.clone();
         let task_context = Arc::new(TaskContext::new(
-            Some(task_identity.to_string()),
+            Some(task_identity.clone()),
             session_id,
             session_config,
             task_scalar_functions,
@@ -387,11 +336,28 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
             runtime.clone(),
         ));
 
+        let encoded_plan = task.plan.as_slice();
+
+        let plan: Arc<dyn ExecutionPlan> =
+            U::try_decode(encoded_plan).and_then(|proto| {
+                proto.try_into_physical_plan(
+                    task_context.deref(),
+                    runtime.deref(),
+                    self.codec.physical_extension_codec(),
+                )
+            })?;
+
         let task_id = task.task_id;
         let job_id = task.job_id;
         let stage_id = task.stage_id;
         let stage_attempt_num = task.stage_attempt_num;
         let partition_id = task.partition_id;
+        let query_stage_exec = self.executor.execution_engine.create_query_stage_exec(
+            job_id.clone(),
+            stage_id,
+            plan,
+            &self.executor.work_dir,
+        )?;
 
         let part = PartitionId {
             job_id: job_id.clone(),
@@ -440,6 +406,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
             task_execution_times,
         );
 
+        let scheduler_id = curator_task.scheduler_id;
         let task_status_sender = self.executor_env.tx_task_status.clone();
         task_status_sender
             .send(CuratorTaskStatus {
@@ -505,18 +472,6 @@ struct TaskRunnerPool<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
     executor_server: Arc<ExecutorServer<T, U>>,
 }
 
-fn task_identity(task: &TaskDefinition) -> String {
-    format!(
-        "TID {} {}/{}.{}/{}.{}",
-        &task.task_id,
-        &task.job_id,
-        &task.stage_id,
-        &task.stage_attempt_num,
-        &task.partition_id,
-        &task.task_attempt_num,
-    )
-}
-
 impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskRunnerPool<T, U> {
     fn new(executor_server: Arc<ExecutorServer<T, U>>) -> Self {
         Self { executor_server }
@@ -638,64 +593,30 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskRunnerPool<T,
                         return;
                     }
                 };
-                if let Some(task) = maybe_task {
+                if let Some(curator_task) = maybe_task {
+                    let task_identity = format!(
+                        "TID {} {}/{}.{}/{}.{}",
+                        &curator_task.task.task_id,
+                        &curator_task.task.job_id,
+                        &curator_task.task.stage_id,
+                        &curator_task.task.stage_attempt_num,
+                        &curator_task.task.partition_id,
+                        &curator_task.task.task_attempt_num,
+                    );
+                    info!("Received task {:?}", &task_identity);
+
                     let server = executor_server.clone();
-                    let plan = task.plan;
-                    let curator_task = task.tasks[0].clone();
-                    let out: tokio::sync::oneshot::Receiver<
-                        Result<Arc<dyn QueryStageExecutor>, BallistaError>,
-                    > = dedicated_executor.spawn(async move {
-                        server.decode_task(curator_task, &plan).await
+                    dedicated_executor.spawn(async move {
+                        server
+                            .run_task(task_identity.clone(), curator_task)
+                            .await
+                            .unwrap_or_else(|e| {
+                                error!(
+                                    "Fail to run the task {:?} due to {:?}",
+                                    task_identity, e
+                                );
+                            });
                     });
-
-                    let plan = out.await;
-
-                    let plan = match plan {
-                        Ok(Ok(plan)) => plan,
-                        Ok(Err(e)) => {
-                            error!(
-                                "Failed to decode the plan of task {:?} due to {:?}",
-                                task_identity(&task.tasks[0]),
-                                e
-                            );
-                            return;
-                        }
-                        Err(e) => {
-                            error!(
-                                "Failed to receive error plan of task {:?} due to {:?}",
-                                task_identity(&task.tasks[0]),
-                                e
-                            );
-                            return;
-                        }
-                    };
-                    let scheduler_id = task.scheduler_id.clone();
-
-                    for curator_task in task.tasks {
-                        let plan = plan.clone();
-                        let scheduler_id = scheduler_id.clone();
-
-                        let task_identity = task_identity(&curator_task);
-                        info!("Received task {:?}", &task_identity);
-
-                        let server = executor_server.clone();
-                        dedicated_executor.spawn(async move {
-                            server
-                                .run_task(
-                                    &task_identity,
-                                    scheduler_id,
-                                    curator_task,
-                                    plan,
-                                )
-                                .await
-                                .unwrap_or_else(|e| {
-                                    error!(
-                                        "Fail to run the task {:?} due to {:?}",
-                                        task_identity, e
-                                    );
-                                });
-                        });
-                    }
                 } else {
                     info!("Channel is closed and will exit the task receive loop");
                     drop(task_runner_complete);
@@ -720,15 +641,12 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorGrpc
         } = request.into_inner();
         let task_sender = self.executor_env.tx_task.clone();
         for task in tasks {
-            let (task_def, plan) = task
-                .try_into()
-                .map_err(|e| Status::invalid_argument(format!("{e}")))?;
-
             task_sender
                 .send(CuratorTaskDefinition {
                     scheduler_id: scheduler_id.clone(),
-                    plan,
-                    tasks: vec![task_def],
+                    task: task
+                        .try_into()
+                        .map_err(|e| Status::invalid_argument(format!("{e}")))?,
                 })
                 .await
                 .unwrap();
@@ -748,17 +666,18 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorGrpc
         } = request.into_inner();
         let task_sender = self.executor_env.tx_task.clone();
         for multi_task in multi_tasks {
-            let (multi_task, plan): (Vec<TaskDefinition>, Vec<u8>) = multi_task
+            let multi_task: Vec<TaskDefinition> = multi_task
                 .try_into()
                 .map_err(|e| Status::invalid_argument(format!("{e}")))?;
-            task_sender
-                .send(CuratorTaskDefinition {
-                    scheduler_id: scheduler_id.clone(),
-                    plan,
-                    tasks: multi_task,
-                })
-                .await
-                .unwrap();
+            for task in multi_task {
+                task_sender
+                    .send(CuratorTaskDefinition {
+                        scheduler_id: scheduler_id.clone(),
+                        task,
+                    })
+                    .await
+                    .unwrap();
+            }
         }
         Ok(Response::new(LaunchMultiTaskResult { success: true }))
     }

From 265959127e3880910dabd7181f3478eb19316c0f Mon Sep 17 00:00:00 2001
From: yangzhong <yangzhong@ebay.com>
Date: Mon, 26 Jun 2023 18:07:31 +0800
Subject: [PATCH 2/3] Refactor TaskDefinition by replacing the encoded
 execution plan with the decoded one

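TaskDefinition now holds a decoded Arc<dyn ExecutionPlan> together with an
Arc'd property map and a SimpleFunctionRegistry, instead of the raw encoded
plan bytes. Plan decoding moves out of run_task into the new
get_task_definition and get_task_definition_vec helpers in from_proto.rs, so
a MultiTaskDefinition is decoded once per request and each task receives a
copy of the plan with its metrics reset. The Into<protobuf::TaskDefinition>
conversion in to_proto.rs is removed since the struct no longer carries the
encoded plan.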
---
 .../core/src/serde/scheduler/from_proto.rs    | 170 +++++++++++++-----
 ballista/core/src/serde/scheduler/mod.rs      |  46 ++++-
 ballista/core/src/serde/scheduler/to_proto.rs |  33 +---
 ballista/executor/src/executor_server.rs      |  94 +++++-----
 4 files changed, 215 insertions(+), 128 deletions(-)

diff --git a/ballista/core/src/serde/scheduler/from_proto.rs b/ballista/core/src/serde/scheduler/from_proto.rs
index d93852118..545896d8b 100644
--- a/ballista/core/src/serde/scheduler/from_proto.rs
+++ b/ballista/core/src/serde/scheduler/from_proto.rs
@@ -16,10 +16,15 @@
 // under the License.
 
 use chrono::{TimeZone, Utc};
+use datafusion::common::tree_node::{Transformed, TreeNode};
+use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
 use datafusion::physical_plan::metrics::{
     Count, Gauge, MetricValue, MetricsSet, Time, Timestamp,
 };
-use datafusion::physical_plan::Metric;
+use datafusion::physical_plan::{ExecutionPlan, Metric};
+use datafusion_proto::logical_plan::AsLogicalPlan;
+use datafusion_proto::physical_plan::AsExecutionPlan;
 use std::collections::HashMap;
 use std::convert::TryInto;
 use std::sync::Arc;
@@ -28,10 +33,10 @@ use std::time::Duration;
 use crate::error::BallistaError;
 use crate::serde::scheduler::{
     Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
-    PartitionLocation, PartitionStats, TaskDefinition,
+    PartitionLocation, PartitionStats, SimpleFunctionRegistry, TaskDefinition,
 };
 
-use crate::serde::protobuf;
+use crate::serde::{protobuf, BallistaCodec};
 use protobuf::{operator_metric, NamedCount, NamedGauge, NamedTime};
 
 impl TryInto<Action> for protobuf::Action {
@@ -269,61 +274,138 @@ impl Into<ExecutorData> for protobuf::ExecutorData {
     }
 }
 
-impl TryInto<TaskDefinition> for protobuf::TaskDefinition {
-    type Error = BallistaError;
-
-    fn try_into(self) -> Result<TaskDefinition, Self::Error> {
-        let mut props = HashMap::new();
-        for kv_pair in self.props {
-            props.insert(kv_pair.key, kv_pair.value);
-        }
+pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
+    task: protobuf::TaskDefinition,
+    runtime: Arc<RuntimeEnv>,
+    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+    codec: BallistaCodec<T, U>,
+) -> Result<TaskDefinition, BallistaError> {
+    let mut props = HashMap::new();
+    for kv_pair in task.props {
+        props.insert(kv_pair.key, kv_pair.value);
+    }
+    let props = Arc::new(props);
 
-        Ok(TaskDefinition {
-            task_id: self.task_id as usize,
-            task_attempt_num: self.task_attempt_num as usize,
-            job_id: self.job_id,
-            stage_id: self.stage_id as usize,
-            stage_attempt_num: self.stage_attempt_num as usize,
-            partition_id: self.partition_id as usize,
-            plan: self.plan,
-            session_id: self.session_id,
-            launch_time: self.launch_time,
-            props,
-        })
+    let mut task_scalar_functions = HashMap::new();
+    let mut task_aggregate_functions = HashMap::new();
+    // TODO combine the functions from Executor's functions and TaskDefinition's function resources
+    for scalar_func in scalar_functions {
+        task_scalar_functions.insert(scalar_func.0, scalar_func.1);
+    }
+    for agg_func in aggregate_functions {
+        task_aggregate_functions.insert(agg_func.0, agg_func.1);
     }
+    let function_registry = Arc::new(SimpleFunctionRegistry {
+        scalar_functions: task_scalar_functions,
+        aggregate_functions: task_aggregate_functions,
+    });
+
+    let encoded_plan = task.plan.as_slice();
+    let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
+        proto.try_into_physical_plan(
+            function_registry.as_ref(),
+            runtime.as_ref(),
+            codec.physical_extension_codec(),
+        )
+    })?;
+
+    let job_id = task.job_id;
+    let stage_id = task.stage_id as usize;
+    let partition_id = task.partition_id as usize;
+    let task_attempt_num = task.task_attempt_num as usize;
+    let stage_attempt_num = task.stage_attempt_num as usize;
+    let launch_time = task.launch_time;
+    let task_id = task.task_id as usize;
+    let session_id = task.session_id;
+
+    Ok(TaskDefinition {
+        task_id,
+        task_attempt_num,
+        job_id,
+        stage_id,
+        stage_attempt_num,
+        partition_id,
+        plan,
+        launch_time,
+        session_id,
+        props,
+        function_registry,
+    })
 }
 
-impl TryInto<Vec<TaskDefinition>> for protobuf::MultiTaskDefinition {
-    type Error = BallistaError;
+pub fn get_task_definition_vec<
+    T: 'static + AsLogicalPlan,
+    U: 'static + AsExecutionPlan,
+>(
+    multi_task: protobuf::MultiTaskDefinition,
+    runtime: Arc<RuntimeEnv>,
+    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+    codec: BallistaCodec<T, U>,
+) -> Result<Vec<TaskDefinition>, BallistaError> {
+    let mut props = HashMap::new();
+    for kv_pair in multi_task.props {
+        props.insert(kv_pair.key, kv_pair.value);
+    }
+    let props = Arc::new(props);
 
-    fn try_into(self) -> Result<Vec<TaskDefinition>, Self::Error> {
-        let mut props = HashMap::new();
-        for kv_pair in self.props {
-            props.insert(kv_pair.key, kv_pair.value);
-        }
+    let mut task_scalar_functions = HashMap::new();
+    let mut task_aggregate_functions = HashMap::new();
+    // TODO combine the functions from Executor's functions and TaskDefinition's function resources
+    for scalar_func in scalar_functions {
+        task_scalar_functions.insert(scalar_func.0, scalar_func.1);
+    }
+    for agg_func in aggregate_functions {
+        task_aggregate_functions.insert(agg_func.0, agg_func.1);
+    }
+    let function_registry = Arc::new(SimpleFunctionRegistry {
+        scalar_functions: task_scalar_functions,
+        aggregate_functions: task_aggregate_functions,
+    });
 
-        let plan = self.plan;
-        let session_id = self.session_id;
-        let job_id = self.job_id;
-        let stage_id = self.stage_id as usize;
-        let stage_attempt_num = self.stage_attempt_num as usize;
-        let launch_time = self.launch_time;
-        let task_ids = self.task_ids;
+    let encoded_plan = multi_task.plan.as_slice();
+    let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
+        proto.try_into_physical_plan(
+            function_registry.as_ref(),
+            runtime.as_ref(),
+            codec.physical_extension_codec(),
+        )
+    })?;
+
+    let job_id = multi_task.job_id;
+    let stage_id = multi_task.stage_id as usize;
+    let stage_attempt_num = multi_task.stage_attempt_num as usize;
+    let launch_time = multi_task.launch_time;
+    let task_ids = multi_task.task_ids;
+    let session_id = multi_task.session_id;
 
-        Ok(task_ids
-            .iter()
-            .map(|task_id| TaskDefinition {
+    task_ids
+        .iter()
+        .map(|task_id| {
+            Ok(TaskDefinition {
                 task_id: task_id.task_id as usize,
                 task_attempt_num: task_id.task_attempt_num as usize,
                 job_id: job_id.clone(),
                 stage_id,
                 stage_attempt_num,
                 partition_id: task_id.partition_id as usize,
-                plan: plan.clone(),
-                session_id: session_id.clone(),
+                plan: reset_metrics_for_execution_plan(plan.clone())?,
                 launch_time,
+                session_id: session_id.clone(),
                 props: props.clone(),
+                function_registry: function_registry.clone(),
             })
-            .collect())
-    }
+        })
+        .collect()
+}
+
+fn reset_metrics_for_execution_plan(
+    plan: Arc<dyn ExecutionPlan>,
+) -> Result<Arc<dyn ExecutionPlan>, BallistaError> {
+    plan.transform(&|plan| {
+        let children = plan.children().clone();
+        plan.with_new_children(children).map(Transformed::Yes)
+    })
+    .map_err(BallistaError::DataFusionError)
 }
diff --git a/ballista/core/src/serde/scheduler/mod.rs b/ballista/core/src/serde/scheduler/mod.rs
index 6e9440a36..96c4e0fa7 100644
--- a/ballista/core/src/serde/scheduler/mod.rs
+++ b/ballista/core/src/serde/scheduler/mod.rs
@@ -15,12 +15,17 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::collections::HashSet;
+use std::fmt::Debug;
 use std::{collections::HashMap, fmt, sync::Arc};
 
 use datafusion::arrow::array::{
     ArrayBuilder, StructArray, StructBuilder, UInt64Array, UInt64Builder,
 };
 use datafusion::arrow::datatypes::{DataType, Field};
+use datafusion::common::DataFusionError;
+use datafusion::execution::FunctionRegistry;
+use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::physical_plan::Partitioning;
 use serde::Serialize;
@@ -271,7 +276,7 @@ impl ExecutePartitionResult {
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Clone, Debug)]
 pub struct TaskDefinition {
     pub task_id: usize,
     pub task_attempt_num: usize,
@@ -279,8 +284,41 @@ pub struct TaskDefinition {
     pub stage_id: usize,
     pub stage_attempt_num: usize,
     pub partition_id: usize,
-    pub plan: Vec<u8>,
-    pub session_id: String,
+    pub plan: Arc<dyn ExecutionPlan>,
     pub launch_time: u64,
-    pub props: HashMap<String, String>,
+    pub session_id: String,
+    pub props: Arc<HashMap<String, String>>,
+    pub function_registry: Arc<SimpleFunctionRegistry>,
+}
+
+#[derive(Debug)]
+pub struct SimpleFunctionRegistry {
+    pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+}
+
+impl FunctionRegistry for SimpleFunctionRegistry {
+    fn udfs(&self) -> HashSet<String> {
+        self.scalar_functions.keys().cloned().collect()
+    }
+
+    fn udf(&self, name: &str) -> datafusion::common::Result<Arc<ScalarUDF>> {
+        let result = self.scalar_functions.get(name);
+
+        result.cloned().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "There is no UDF named \"{name}\" in the TaskContext"
+            ))
+        })
+    }
+
+    fn udaf(&self, name: &str) -> datafusion::common::Result<Arc<AggregateUDF>> {
+        let result = self.aggregate_functions.get(name);
+
+        result.cloned().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "There is no UDAF named \"{name}\" in the TaskContext"
+            ))
+        })
+    }
 }
diff --git a/ballista/core/src/serde/scheduler/to_proto.rs b/ballista/core/src/serde/scheduler/to_proto.rs
index ccb5ec427..6ceb1dd6e 100644
--- a/ballista/core/src/serde/scheduler/to_proto.rs
+++ b/ballista/core/src/serde/scheduler/to_proto.rs
@@ -26,12 +26,10 @@ use datafusion_proto::protobuf as datafusion_protobuf;
 
 use crate::serde::scheduler::{
     Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
-    PartitionLocation, PartitionStats, TaskDefinition,
+    PartitionLocation, PartitionStats,
 };
 use datafusion::physical_plan::Partitioning;
-use protobuf::{
-    action::ActionType, operator_metric, KeyValuePair, NamedCount, NamedGauge, NamedTime,
-};
+use protobuf::{action::ActionType, operator_metric, NamedCount, NamedGauge, NamedTime};
 
 impl TryInto<protobuf::Action> for Action {
     type Error = BallistaError;
@@ -242,30 +240,3 @@ impl Into<protobuf::ExecutorData> for ExecutorData {
         }
     }
 }
-
-#[allow(clippy::from_over_into)]
-impl Into<protobuf::TaskDefinition> for TaskDefinition {
-    fn into(self) -> protobuf::TaskDefinition {
-        let props = self
-            .props
-            .iter()
-            .map(|(k, v)| KeyValuePair {
-                key: k.to_owned(),
-                value: v.to_owned(),
-            })
-            .collect::<Vec<_>>();
-
-        protobuf::TaskDefinition {
-            task_id: self.task_id as u32,
-            task_attempt_num: self.task_attempt_num as u32,
-            job_id: self.job_id,
-            stage_id: self.stage_id as u32,
-            stage_attempt_num: self.stage_attempt_num as u32,
-            partition_id: self.partition_id as u32,
-            plan: self.plan,
-            session_id: self.session_id,
-            launch_time: self.launch_time,
-            props,
-        }
-    }
-}
diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs
index 3b47a23ab..861455313 100644
--- a/ballista/executor/src/executor_server.rs
+++ b/ballista/executor/src/executor_server.rs
@@ -16,11 +16,8 @@
 // under the License.
 
 use ballista_core::BALLISTA_VERSION;
-use datafusion::config::ConfigOptions;
-use datafusion::prelude::SessionConfig;
 use std::collections::HashMap;
 use std::convert::TryInto;
-use std::ops::Deref;
 use std::path::{Path, PathBuf};
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
@@ -41,13 +38,17 @@ use ballista_core::serde::protobuf::{
     LaunchTaskResult, RegisterExecutorParams, RemoveJobDataParams, RemoveJobDataResult,
     StopExecutorParams, StopExecutorResult, TaskStatus, UpdateTaskStatusParams,
 };
+use ballista_core::serde::scheduler::from_proto::{
+    get_task_definition, get_task_definition_vec,
+};
 use ballista_core::serde::scheduler::PartitionId;
 use ballista_core::serde::scheduler::TaskDefinition;
 use ballista_core::serde::BallistaCodec;
 use ballista_core::utils::{create_grpc_client_connection, create_grpc_server};
 use dashmap::DashMap;
-use datafusion::execution::context::TaskContext;
-use datafusion::physical_plan::ExecutionPlan;
+use datafusion::config::ConfigOptions;
+use datafusion::execution::TaskContext;
+use datafusion::prelude::SessionConfig;
 use datafusion_proto::{logical_plan::AsLogicalPlan, physical_plan::AsExecutionPlan};
 use tokio::sync::mpsc::error::TryRecvError;
 use tokio::task::JoinHandle;
@@ -308,50 +309,14 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
             .as_millis() as u64;
         info!("Start to run task {}", task_identity);
         let task = curator_task.task;
-        let task_props = task.props;
-        let mut config = ConfigOptions::new();
-        for (k, v) in task_props {
-            config.set(&k, &v)?;
-        }
-        let session_config = SessionConfig::from(config);
-
-        let mut task_scalar_functions = HashMap::new();
-        let mut task_aggregate_functions = HashMap::new();
-        // TODO combine the functions from Executor's functions and TaskDefintion's function resources
-        for scalar_func in self.executor.scalar_functions.clone() {
-            task_scalar_functions.insert(scalar_func.0, scalar_func.1);
-        }
-        for agg_func in self.executor.aggregate_functions.clone() {
-            task_aggregate_functions.insert(agg_func.0, agg_func.1);
-        }
-
-        let session_id = task.session_id;
-        let runtime = self.executor.runtime.clone();
-        let task_context = Arc::new(TaskContext::new(
-            Some(task_identity.clone()),
-            session_id,
-            session_config,
-            task_scalar_functions,
-            task_aggregate_functions,
-            runtime.clone(),
-        ));
-
-        let encoded_plan = task.plan.as_slice();
-
-        let plan: Arc<dyn ExecutionPlan> =
-            U::try_decode(encoded_plan).and_then(|proto| {
-                proto.try_into_physical_plan(
-                    task_context.deref(),
-                    runtime.deref(),
-                    self.codec.physical_extension_codec(),
-                )
-            })?;
 
         let task_id = task.task_id;
         let job_id = task.job_id;
         let stage_id = task.stage_id;
         let stage_attempt_num = task.stage_attempt_num;
         let partition_id = task.partition_id;
+        let plan = task.plan;
+
         let query_stage_exec = self.executor.execution_engine.create_query_stage_exec(
             job_id.clone(),
             stage_id,
@@ -365,6 +330,27 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
             partition_id,
         };
 
+        let task_context = {
+            let task_props = task.props;
+            let mut config = ConfigOptions::new();
+            for (k, v) in task_props.iter() {
+                config.set(k, v)?;
+            }
+            let session_config = SessionConfig::from(config);
+
+            let function_registry = task.function_registry;
+            let runtime = self.executor.runtime.clone();
+
+            Arc::new(TaskContext::new(
+                Some(task_identity.clone()),
+                task.session_id,
+                session_config,
+                function_registry.scalar_functions.clone(),
+                function_registry.aggregate_functions.clone(),
+                runtime,
+            ))
+        };
+
         info!("Start to execute shuffle write for task {}", task_identity);
 
         let execution_result = self
@@ -644,9 +630,14 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorGrpc
             task_sender
                 .send(CuratorTaskDefinition {
                     scheduler_id: scheduler_id.clone(),
-                    task: task
-                        .try_into()
-                        .map_err(|e| Status::invalid_argument(format!("{e}")))?,
+                    task: get_task_definition(
+                        task,
+                        self.executor.runtime.clone(),
+                        self.executor.scalar_functions.clone(),
+                        self.executor.aggregate_functions.clone(),
+                        self.codec.clone(),
+                    )
+                    .map_err(|e| Status::invalid_argument(format!("{e}")))?,
                 })
                 .await
                 .unwrap();
@@ -666,9 +657,14 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorGrpc
         } = request.into_inner();
         let task_sender = self.executor_env.tx_task.clone();
         for multi_task in multi_tasks {
-            let multi_task: Vec<TaskDefinition> = multi_task
-                .try_into()
-                .map_err(|e| Status::invalid_argument(format!("{e}")))?;
+            let multi_task: Vec<TaskDefinition> = get_task_definition_vec(
+                multi_task,
+                self.executor.runtime.clone(),
+                self.executor.scalar_functions.clone(),
+                self.executor.aggregate_functions.clone(),
+                self.codec.clone(),
+            )
+            .map_err(|e| Status::invalid_argument(format!("{e}")))?;
             for task in multi_task {
                 task_sender
                     .send(CuratorTaskDefinition {

From 9c1855b728b9f06519f46d977ab128449ecfb552 Mon Sep 17 00:00:00 2001
From: yangzhong <yangzhong@ebay.com>
Date: Mon, 26 Jun 2023 19:19:21 +0800
Subject: [PATCH 3/3] Refine the error handling of run_task in the
 executor_server

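run_task no longer returns a Result: a failure to set an individual session
config option is logged at debug level, and an error while converting
operator metrics causes the task status to be sent without metrics rather
than aborting the task, so the status update still reaches the scheduler.
The caller in the task runner loop drops its unwrap_or_else error handling
accordingly.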
---
 ballista/executor/src/executor_server.rs | 51 ++++++++++++------------
 1 file changed, 25 insertions(+), 26 deletions(-)

diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs
index 861455313..2892cb0b1 100644
--- a/ballista/executor/src/executor_server.rs
+++ b/ballista/executor/src/executor_server.rs
@@ -298,11 +298,9 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
         }
     }
 
-    async fn run_task(
-        &self,
-        task_identity: String,
-        curator_task: CuratorTaskDefinition,
-    ) -> Result<(), BallistaError> {
+    /// This method should not return Err. If a task fails, a failure task status should be sent
+    /// to the channel to notify the scheduler.
+    async fn run_task(&self, task_identity: String, curator_task: CuratorTaskDefinition) {
         let start_exec_time = SystemTime::now()
             .duration_since(UNIX_EPOCH)
             .unwrap()
@@ -317,24 +315,30 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
         let partition_id = task.partition_id;
         let plan = task.plan;
 
-        let query_stage_exec = self.executor.execution_engine.create_query_stage_exec(
-            job_id.clone(),
-            stage_id,
-            plan,
-            &self.executor.work_dir,
-        )?;
-
         let part = PartitionId {
             job_id: job_id.clone(),
             stage_id,
             partition_id,
         };
 
+        let query_stage_exec = self
+            .executor
+            .execution_engine
+            .create_query_stage_exec(
+                job_id.clone(),
+                stage_id,
+                plan,
+                &self.executor.work_dir,
+            )
+            .unwrap();
+
         let task_context = {
             let task_props = task.props;
             let mut config = ConfigOptions::new();
             for (k, v) in task_props.iter() {
-                config.set(k, v)?;
+                if let Err(e) = config.set(k, v) {
+                    debug!("Fail to set session config for ({},{}): {:?}", k, v, e);
+                }
             }
             let session_config = SessionConfig::from(config);
 
@@ -366,10 +370,14 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
         debug!("Statistics: {:?}", execution_result);
 
         let plan_metrics = query_stage_exec.collect_plan_metrics();
-        let operator_metrics = plan_metrics
+        let operator_metrics = match plan_metrics
             .into_iter()
             .map(|m| m.try_into())
-            .collect::<Result<Vec<_>, BallistaError>>()?;
+            .collect::<Result<Vec<_>, BallistaError>>()
+        {
+            Ok(metrics) => Some(metrics),
+            Err(_) => None,
+        };
         let executor_id = &self.executor.metadata.id;
 
         let end_exec_time = SystemTime::now()
@@ -388,7 +396,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
             task_id,
             stage_attempt_num,
             part,
-            Some(operator_metrics),
+            operator_metrics,
             task_execution_times,
         );
 
@@ -401,7 +409,6 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
             })
             .await
             .unwrap();
-        Ok(())
     }
 
     // TODO populate with real metrics
@@ -593,15 +600,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskRunnerPool<T,
 
                     let server = executor_server.clone();
                     dedicated_executor.spawn(async move {
-                        server
-                            .run_task(task_identity.clone(), curator_task)
-                            .await
-                            .unwrap_or_else(|e| {
-                                error!(
-                                    "Fail to run the task {:?} due to {:?}",
-                                    task_identity, e
-                                );
-                            });
+                        server.run_task(task_identity.clone(), curator_task).await;
                     });
                 } else {
                     info!("Channel is closed and will exit the task receive loop");