diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs index 6d517bf9e..176b85c89 100644 --- a/ballista/scheduler/src/scheduler_server/mod.rs +++ b/ballista/scheduler/src/scheduler_server/mod.rs @@ -454,10 +454,9 @@ mod test { let graph = scheduler .state .task_manager - .get_active_execution_graph(job_id) - .await + .get_job_execution_graph(job_id) + .await? .unwrap(); - let graph = graph.read().await; if graph.is_successful() { break; } diff --git a/ballista/scheduler/src/state/task_manager.rs b/ballista/scheduler/src/state/task_manager.rs index d005e269a..c73612f51 100644 --- a/ballista/scheduler/src/state/task_manager.rs +++ b/ballista/scheduler/src/state/task_manager.rs @@ -300,7 +300,7 @@ impl TaskManager let lock = self.state.lock(Keyspace::ActiveJobs, "").await?; with_lock(lock, self.state.delete(Keyspace::ActiveJobs, job_id)).await?; - if let Some(graph) = self.get_active_execution_graph(job_id).await { + if let Some(graph) = self.remove_active_execution_graph(job_id).await { let graph = graph.read().await.clone(); if graph.is_successful() { let value = self.encode_execution_graph(graph)?; @@ -423,7 +423,7 @@ impl TaskManager ] }; - let _res = if let Some(graph) = self.get_active_execution_graph(job_id).await { + let _res = if let Some(graph) = self.remove_active_execution_graph(job_id).await { let mut graph = graph.write().await; let previous_status = graph.status(); graph.fail_job(failure_reason); @@ -592,6 +592,14 @@ impl TaskManager self.active_job_cache.get(job_id).map(|value| value.clone()) } + /// Remove the `ExecutionGraph` for the given job ID from cache + pub(crate) async fn remove_active_execution_graph( + &self, + job_id: &str, + ) -> Option>> { + self.active_job_cache.remove(job_id).map(|value| value.1) + } + /// Get the `ExecutionGraph` for the given job ID. This will search fist in the `ActiveJobs` /// keyspace and then, if it doesn't find anything, search the `CompletedJobs` keyspace. pub(crate) async fn get_execution_graph(