diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml index 253052fd0..ca3532ac7 100644 --- a/ballista/rust/executor/Cargo.toml +++ b/ballista/rust/executor/Cargo.toml @@ -40,6 +40,7 @@ async-trait = "0.1.41" ballista-core = { path = "../core", version = "0.8.0" } chrono = { version = "0.4", default-features = false } configure_me = "0.4.0" +dashmap = "5.4.0" datafusion = { git = "https://github.com/apache/arrow-datafusion", rev = "06a4f79f02fcb6ea85303925b7c5a9b0231e3fee" } datafusion-proto = { git = "https://github.com/apache/arrow-datafusion", rev = "06a4f79f02fcb6ea85303925b7c5a9b0231e3fee" } futures = "0.3" diff --git a/ballista/rust/executor/src/executor_server.rs b/ballista/rust/executor/src/executor_server.rs index 2bc84f72c..bf036d6f1 100644 --- a/ballista/rust/executor/src/executor_server.rs +++ b/ballista/rust/executor/src/executor_server.rs @@ -21,7 +21,7 @@ use std::convert::TryInto; use std::ops::Deref; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use tokio::sync::{mpsc, RwLock}; +use tokio::sync::mpsc; use log::{debug, error, info, warn}; use tonic::transport::Channel; @@ -45,6 +45,7 @@ use ballista_core::serde::{AsExecutionPlan, BallistaCodec}; use ballista_core::utils::{ collect_plan_metrics, create_grpc_client_connection, create_grpc_server, }; +use dashmap::DashMap; use datafusion::execution::context::TaskContext; use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::logical_plan::AsLogicalPlan; @@ -57,7 +58,7 @@ use crate::shutdown::ShutdownNotifier; use crate::{as_task_status, TaskExecutionTimes}; type ServerHandle = JoinHandle>; -type SchedulerClients = Arc>>>; +type SchedulerClients = Arc>>; /// Wrap TaskDefinition with its curator scheduler id for task update to its specific curator scheduler later #[derive(Debug)] @@ -216,10 +217,7 @@ impl ExecutorServer Result, BallistaError> { - let scheduler = { - let schedulers = self.schedulers.read().await; - schedulers.get(scheduler_id).cloned() - }; + let scheduler = self.schedulers.get(scheduler_id).map(|value| value.clone()); // If channel does not exist, create a new one if let Some(scheduler) = scheduler { Ok(scheduler) @@ -229,8 +227,8 @@ impl ExecutorServer ExecutorServer)>) -> Result<()> { + /// Apply multiple operations in a single transaction. 
+ async fn apply_txn(&self, ops: Vec<(Operation, Keyspace, String)>) -> Result<()> { let mut etcd = self.etcd.clone(); let txn_ops: Vec<TxnOp> = ops .into_iter() - .map(|(ks, key, value)| { + .map(|(operation, ks, key)| { let key = format!("/{}/{:?}/{}", self.namespace, ks, key); - TxnOp::put(key, value, None) + match operation { + Operation::Put(value) => TxnOp::put(key, value, None), + Operation::Delete => TxnOp::delete(key, None), + } }) .collect(); - let txn = Txn::new().and_then(txn_ops); - - etcd.txn(txn) + etcd.txn(Txn::new().and_then(txn_ops)) .await .map_err(|e| { - error!("etcd put failed: {}", e); - ballista_error("etcd transaction put failed") + error!("etcd operation failed: {}", e); + ballista_error(&*format!("etcd operation failed: {}", e)) }) .map(|_| ()) } diff --git a/ballista/rust/scheduler/src/state/backend/mod.rs b/ballista/rust/scheduler/src/state/backend/mod.rs index b69403b2e..85c0f34d8 100644 --- a/ballista/rust/scheduler/src/state/backend/mod.rs +++ b/ballista/rust/scheduler/src/state/backend/mod.rs @@ -17,7 +17,7 @@ use ballista_core::error::Result; use clap::ArgEnum; -use futures::Stream; +use futures::{future, Stream}; use std::collections::HashSet; use std::fmt; use tokio::sync::OwnedMutexGuard; @@ -60,6 +60,12 @@ pub enum Keyspace { Heartbeats, } +#[derive(Debug, Eq, PartialEq, Hash)] +pub enum Operation { + Put(Vec<u8>), + Delete, +} + /// A trait that contains the necessary methods to save and retrieve the state and configuration of a cluster. #[tonic::async_trait] pub trait StateBackendClient: Send + Sync { @@ -90,8 +96,19 @@ pub trait StateBackendClient: Send + Sync { /// Saves the value into the provided key, overriding any previous data that might have been associated to that key. async fn put(&self, keyspace: Keyspace, key: String, value: Vec<u8>) -> Result<()>; - /// Save multiple values in a single transaction. Either all values should be saved, or all should fail - async fn put_txn(&self, ops: Vec<(Keyspace, String, Vec<u8>)>) -> Result<()>; + /// Bundle multiple operations in a single transaction. Either all operations should succeed, or all should fail. + /// It supports multiple types of operations and keyspaces. If more than one unique keyspace is involved, + /// more than one lock has to be acquired. + async fn apply_txn(&self, ops: Vec<(Operation, Keyspace, String)>) -> Result<()>; + /// Acquire locks for the specified (keyspace, key) IDs. + async fn acquire_locks( + &self, + mut ids: Vec<(Keyspace, &str)>, + ) -> Result<Vec<Box<dyn Lock>>> { + // We always acquire locks in a specific order to avoid deadlocks. + ids.sort_by_key(|n| format!("/{:?}/{}", n.0, n.1)); + future::try_join_all(ids.into_iter().map(|(ks, key)| self.lock(ks, key))).await + } /// Atomically move the given key from one keyspace to another async fn mv( diff --git a/ballista/rust/scheduler/src/state/backend/standalone.rs b/ballista/rust/scheduler/src/state/backend/standalone.rs index 4e5dc063d..57bf7470c 100644 --- a/ballista/rust/scheduler/src/state/backend/standalone.rs +++ b/ballista/rust/scheduler/src/state/backend/standalone.rs @@ -25,7 +25,9 @@ use log::warn; use sled_package as sled; use tokio::sync::Mutex; -use crate::state::backend::{Keyspace, Lock, StateBackendClient, Watch, WatchEvent}; +use crate::state::backend::{ + Keyspace, Lock, Operation, StateBackendClient, Watch, WatchEvent, +}; /// A [`StateBackendClient`] implementation that uses file-based storage to save cluster configuration.
#[derive(Clone)] @@ -162,17 +164,20 @@ impl StateBackendClient for StandaloneClient { .map(|_| ()) } - async fn put_txn(&self, ops: Vec<(Keyspace, String, Vec)>) -> Result<()> { + async fn apply_txn(&self, ops: Vec<(Operation, Keyspace, String)>) -> Result<()> { let mut batch = sled::Batch::default(); - for (ks, key, value) in ops { - let key = format!("/{:?}/{}", ks, key); - batch.insert(key.as_str(), value); + for (op, keyspace, key_str) in ops { + let key = format!("/{:?}/{}", &keyspace, key_str); + match op { + Operation::Put(value) => batch.insert(key.as_str(), value), + Operation::Delete => batch.remove(key.as_str()), + } } self.db.apply_batch(batch).map_err(|e| { warn!("sled transaction insert failed: {}", e); - ballista_error("sled insert failed") + ballista_error("sled operations failed") }) } @@ -279,7 +284,8 @@ impl Stream for SledWatch { mod tests { use super::{StandaloneClient, StateBackendClient, Watch, WatchEvent}; - use crate::state::backend::Keyspace; + use crate::state::backend::{Keyspace, Operation}; + use crate::state::with_locks; use futures::StreamExt; use std::result::Result; @@ -299,6 +305,34 @@ mod tests { Ok(()) } + #[tokio::test] + async fn multiple_operation() -> Result<(), Box> { + let client = create_instance()?; + let key = "key".to_string(); + let value = "value".as_bytes().to_vec(); + let locks = client + .acquire_locks(vec![(Keyspace::ActiveJobs, ""), (Keyspace::Slots, "")]) + .await?; + + let _r: ballista_core::error::Result<()> = with_locks(locks, async { + let txn_ops = vec![ + (Operation::Put(value.clone()), Keyspace::Slots, key.clone()), + ( + Operation::Put(value.clone()), + Keyspace::ActiveJobs, + key.clone(), + ), + ]; + client.apply_txn(txn_ops).await?; + Ok(()) + }) + .await; + + assert_eq!(client.get(Keyspace::Slots, key.as_str()).await?, value); + assert_eq!(client.get(Keyspace::ActiveJobs, key.as_str()).await?, value); + Ok(()) + } + #[tokio::test] async fn read_empty() -> Result<(), Box> { let client = create_instance()?; diff --git a/ballista/rust/scheduler/src/state/executor_manager.rs b/ballista/rust/scheduler/src/state/executor_manager.rs index 1d135ef84..322e5a7bd 100644 --- a/ballista/rust/scheduler/src/state/executor_manager.rs +++ b/ballista/rust/scheduler/src/state/executor_manager.rs @@ -17,7 +17,7 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; -use crate::state::backend::{Keyspace, StateBackendClient, WatchEvent}; +use crate::state::backend::{Keyspace, Operation, StateBackendClient, WatchEvent}; use crate::state::{decode_into, decode_protobuf, encode_protobuf, with_lock}; use ballista_core::error::{BallistaError, Result}; @@ -30,14 +30,14 @@ use ballista_core::serde::protobuf::{ }; use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; use ballista_core::utils::create_grpc_client_connection; +use dashmap::{DashMap, DashSet}; use futures::StreamExt; use log::{debug, error, info}; -use parking_lot::RwLock; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use tonic::transport::Channel; -type ExecutorClients = Arc>>>; +type ExecutorClients = Arc>>; /// Represents a task slot that is reserved (i.e. available for scheduling but not visible to the /// rest of the system). 
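The backend changes above (`etcd.rs`, `mod.rs`, `standalone.rs`) replace the put-only `put_txn` with a general `apply_txn` that mixes `Put` and `Delete` operations, guarded by `acquire_locks` (which sorts the lock IDs so locks are always taken in the same order) and released through `with_locks`. The sketch below is a minimal, self-contained illustration of that pattern using only `tokio` primitives; `Backend`, `Op`, and this `Keyspace` enum are illustrative stand-ins, not Ballista's types.

```rust
// Simplified model of the locking + batched-transaction pattern above:
// acquire per-keyspace locks in a canonical (sorted) order, apply all
// Put/Delete operations as one batch, then release in reverse order.
// Assumes tokio = { version = "1", features = ["macros", "rt", "sync"] }.
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use tokio::sync::{Mutex, OwnedMutexGuard};

#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
enum Keyspace {
    ActiveJobs,
    Slots,
}

enum Op {
    Put(Vec<u8>),
    Delete,
}

#[derive(Default)]
struct Backend {
    // One mutex per keyspace, plus the key/value data itself.
    locks: Mutex<HashMap<Keyspace, Arc<Mutex<()>>>>,
    data: Mutex<BTreeMap<(Keyspace, String), Vec<u8>>>,
}

impl Backend {
    /// Acquire one lock per keyspace, always in sorted order to avoid deadlocks.
    async fn acquire_locks(&self, mut keyspaces: Vec<Keyspace>) -> Vec<OwnedMutexGuard<()>> {
        keyspaces.sort();
        let mut guards = Vec::new();
        for ks in keyspaces {
            let lock = {
                let mut locks = self.locks.lock().await;
                locks
                    .entry(ks)
                    .or_insert_with(|| Arc::new(Mutex::new(())))
                    .clone()
            };
            guards.push(lock.lock_owned().await);
        }
        guards
    }

    /// Apply all operations as one batch while the caller holds the locks.
    async fn apply_txn(&self, ops: Vec<(Op, Keyspace, String)>) {
        let mut data = self.data.lock().await;
        for (op, ks, key) in ops {
            match op {
                Op::Put(value) => {
                    data.insert((ks, key), value);
                }
                Op::Delete => {
                    data.remove(&(ks, key));
                }
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let backend = Backend::default();
    let guards = backend
        .acquire_locks(vec![Keyspace::Slots, Keyspace::ActiveJobs])
        .await;
    backend
        .apply_txn(vec![
            (Op::Put(b"slots".to_vec()), Keyspace::Slots, "executor-1".into()),
            (Op::Delete, Keyspace::ActiveJobs, "job-1".into()),
        ])
        .await;
    // Release in reverse acquisition order, mirroring `with_locks`.
    for guard in guards.into_iter().rev() {
        drop(guard);
    }
}
```

Taking the per-keyspace locks in one canonical order means two callers can never each hold one lock while waiting on the other's, which is the same reasoning behind `acquire_locks` sorting its IDs by the `/{keyspace}/{key}` string.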
@@ -85,11 +85,11 @@ pub const DEFAULT_EXECUTOR_TIMEOUT_SECONDS: u64 = 180; pub(crate) struct ExecutorManager { state: Arc, // executor_id -> ExecutorMetadata map - executor_metadata: Arc>>, + executor_metadata: Arc>, // executor_id -> ExecutorHeartbeat map - executors_heartbeat: Arc>>, + executors_heartbeat: Arc>, // dead executor sets: - dead_executors: Arc>>, + dead_executors: Arc>, clients: ExecutorClients, } @@ -97,9 +97,9 @@ impl ExecutorManager { pub(crate) fn new(state: Arc) -> Self { Self { state, - executor_metadata: Arc::new(RwLock::new(HashMap::new())), - executors_heartbeat: Arc::new(RwLock::new(HashMap::new())), - dead_executors: Arc::new(RwLock::new(HashSet::new())), + executor_metadata: Arc::new(DashMap::new()), + executors_heartbeat: Arc::new(DashMap::new()), + dead_executors: Arc::new(DashSet::new()), clients: Default::default(), } } @@ -130,7 +130,7 @@ impl ExecutorManager { let alive_executors = self.get_alive_executors_within_one_minute(); - let mut txn_ops: Vec<(Keyspace, String, Vec)> = vec![]; + let mut txn_ops: Vec<(Operation, Keyspace, String)> = vec![]; for executor_id in alive_executors { let value = self.state.get(Keyspace::Slots, &executor_id).await?; @@ -146,14 +146,14 @@ impl ExecutorManager { let proto: protobuf::ExecutorData = data.into(); let new_data = encode_protobuf(&proto)?; - txn_ops.push((Keyspace::Slots, executor_id, new_data)); + txn_ops.push((Operation::Put(new_data), Keyspace::Slots, executor_id)); if desired == 0 { break; } } - self.state.put_txn(txn_ops).await?; + self.state.apply_txn(txn_ops).await?; let elapsed = start.elapsed(); info!( @@ -195,16 +195,16 @@ impl ExecutorManager { } } - let txn_ops: Vec<(Keyspace, String, Vec)> = executor_slots + let txn_ops: Vec<(Operation, Keyspace, String)> = executor_slots .into_iter() .map(|(executor_id, data)| { let proto: protobuf::ExecutorData = data.into(); let new_data = encode_protobuf(&proto)?; - Ok((Keyspace::Slots, executor_id, new_data)) + Ok((Operation::Put(new_data), Keyspace::Slots, executor_id)) }) .collect::>>()?; - self.state.put_txn(txn_ops).await?; + self.state.apply_txn(txn_ops).await?; let elapsed = start.elapsed(); info!( @@ -262,10 +262,7 @@ impl ExecutorManager { &self, executor_id: &str, ) -> Result> { - let client = { - let clients = self.clients.read(); - clients.get(executor_id).cloned() - }; + let client = self.clients.get(executor_id).map(|value| value.clone()); if let Some(client) = client { Ok(client) @@ -279,8 +276,7 @@ impl ExecutorManager { let client = ExecutorGrpcClient::new(connection); { - let mut clients = self.clients.write(); - clients.insert(executor_id.to_owned(), client.clone()); + self.clients.insert(executor_id.to_owned(), client.clone()); } Ok(client) } @@ -289,11 +285,10 @@ impl ExecutorManager { /// Get a list of all executors along with the timestamp of their last recorded heartbeat pub async fn get_executor_state(&self) -> Result> { let heartbeat_timestamps: Vec<(String, u64)> = { - let heartbeats = self.executors_heartbeat.read(); - - heartbeats + self.executors_heartbeat .iter() - .map(|(executor_id, heartbeat)| { + .map(|item| { + let (executor_id, heartbeat) = item.pair(); (executor_id.clone(), heartbeat.timestamp) }) .collect() @@ -316,8 +311,7 @@ impl ExecutorManager { executor_id: &str, ) -> Result { { - let metadata_cache = self.executor_metadata.read(); - if let Some(cached) = metadata_cache.get(executor_id) { + if let Some(cached) = self.executor_metadata.get(executor_id) { return Ok(cached.clone()); } } @@ -468,8 +462,8 @@ impl ExecutorManager 
{ .put(Keyspace::Heartbeats, executor_id, value) .await?; - let mut executors_heartbeat = self.executors_heartbeat.write(); - executors_heartbeat.insert(heartbeat.executor_id.clone(), heartbeat); + self.executors_heartbeat + .insert(heartbeat.executor_id.clone(), heartbeat); Ok(()) } @@ -484,22 +478,19 @@ impl ExecutorManager { .put(Keyspace::Heartbeats, executor_id.clone(), value) .await?; - let mut executors_heartbeat = self.executors_heartbeat.write(); - executors_heartbeat.remove(&heartbeat.executor_id.clone()); - - let mut dead_executors = self.dead_executors.write(); - dead_executors.insert(executor_id); + self.executors_heartbeat + .remove(&heartbeat.executor_id.clone()); + self.dead_executors.insert(executor_id); Ok(()) } pub(crate) fn is_dead_executor(&self, executor_id: &str) -> bool { - self.dead_executors.read().contains(executor_id) + self.dead_executors.contains(executor_id) } /// Initialize the set of active executor heartbeats from storage async fn init_active_executor_heartbeats(&self) -> Result<()> { let heartbeats = self.state.scan(Keyspace::Heartbeats, None).await?; - let mut cache = self.executors_heartbeat.write(); for (_, value) in heartbeats { let data: protobuf::ExecutorHeartbeat = decode_protobuf(&value)?; @@ -508,7 +499,7 @@ impl ExecutorManager { status: Some(executor_status::Status::Active(_)), }) = data.status { - cache.insert(executor_id, data); + self.executors_heartbeat.insert(executor_id, data); } } Ok(()) @@ -520,10 +511,10 @@ impl ExecutorManager { &self, last_seen_ts_threshold: u64, ) -> HashSet { - let executors_heartbeat = self.executors_heartbeat.read(); - executors_heartbeat + self.executors_heartbeat .iter() - .filter_map(|(exec, heartbeat)| { + .filter_map(|pair| { + let (exec, heartbeat) = pair.pair(); (heartbeat.timestamp > last_seen_ts_threshold).then(|| exec.clone()) }) .collect() @@ -539,10 +530,11 @@ impl ExecutorManager { .unwrap_or_else(|| Duration::from_secs(0)) .as_secs(); - let lock = self.executors_heartbeat.read(); - let expired_executors = lock + let expired_executors = self + .executors_heartbeat .iter() - .filter_map(|(_exec, heartbeat)| { + .filter_map(|pair| { + let (_exec, heartbeat) = pair.pair(); (heartbeat.timestamp <= last_seen_threshold).then(|| heartbeat.clone()) }) .collect::>(); @@ -565,15 +557,15 @@ impl ExecutorManager { /// and maintain an in-memory copy of the executor heartbeats. 
struct ExecutorHeartbeatListener { state: Arc, - executors_heartbeat: Arc>>, - dead_executors: Arc>>, + executors_heartbeat: Arc>, + dead_executors: Arc>, } impl ExecutorHeartbeatListener { pub fn new( state: Arc, - executors_heartbeat: Arc>>, - dead_executors: Arc>>, + executors_heartbeat: Arc>, + dead_executors: Arc>, ) -> Self { Self { state, @@ -598,14 +590,13 @@ impl ExecutorHeartbeatListener { decode_protobuf::(&value) { let executor_id = data.executor_id.clone(); - let mut heartbeats = heartbeats.write(); // Remove dead executors if let Some(ExecutorStatus { status: Some(executor_status::Status::Dead(_)), }) = data.status { heartbeats.remove(&executor_id); - dead_executors.write().insert(executor_id); + dead_executors.insert(executor_id); } else { heartbeats.insert(executor_id, data); } diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs index 94d69f6d2..0ecdb3a24 100644 --- a/ballista/rust/scheduler/src/state/mod.rs +++ b/ballista/rust/scheduler/src/state/mod.rs @@ -274,11 +274,23 @@ impl SchedulerState>(lock: Box, op: F) -> Out { - let mut lock = lock; +pub async fn with_lock<Out, F: Future<Output = Out>>( + mut lock: Box<dyn Lock>, + op: F, +) -> Out { let result = op.await; lock.unlock().await; - + result +} +/// Takes multiple locks and releases them in reverse order to prevent a race condition. +pub async fn with_locks<Out, F: Future<Output = Out>>( + locks: Vec<Box<dyn Lock>>, + op: F, +) -> Out { + let result = op.await; + for mut lock in locks.into_iter().rev() { + lock.unlock().await; + } result } diff --git a/ballista/rust/scheduler/src/state/session_registry.rs b/ballista/rust/scheduler/src/state/session_registry.rs index 1281449bd..b6f214e5f 100644 --- a/ballista/rust/scheduler/src/state/session_registry.rs +++ b/ballista/rust/scheduler/src/state/session_registry.rs @@ -15,15 +15,14 @@ // specific language governing permissions and limitations // under the License. +use dashmap::DashMap; use datafusion::prelude::SessionContext; -use std::collections::HashMap; use std::sync::Arc; -use tokio::sync::RwLock; /// A Registry holds all the datafusion session contexts pub struct SessionContextRegistry { /// A map from session_id to SessionContext - pub running_sessions: RwLock>>, + pub running_sessions: DashMap>, } impl Default for SessionContextRegistry { @@ -37,7 +36,7 @@ impl SessionContextRegistry { /// ['LocalFileSystem'] store is registered in by default to support read local files natively. pub fn new() -> Self { Self { - running_sessions: RwLock::new(HashMap::new()), + running_sessions: DashMap::new(), } } @@ -47,14 +46,14 @@ impl SessionContextRegistry { session_ctx: Arc, ) -> Option> { let session_id = session_ctx.session_id(); - let mut sessions = self.running_sessions.write().await; - sessions.insert(session_id, session_ctx) + self.running_sessions.insert(session_id, session_ctx) } /// Lookup the session context registered pub async fn lookup_session(&self, session_id: &str) -> Option> { - let sessions = self.running_sessions.read().await; - sessions.get(session_id).cloned() + self.running_sessions .get(session_id) .map(|value| value.clone()) } /// Remove a session from this registry.
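The `ExecutorHeartbeatListener` and `SessionContextRegistry` changes above (and the `ExecutorManager` ones earlier) all follow the same migration: an `Arc<RwLock<HashMap<..>>>` becomes a `DashMap`, so map access no longer needs an explicit read/write guard or an `.await`. Below is a minimal, self-contained sketch of the resulting access pattern, assuming `dashmap = "5"`; `Registry` and its `Arc<String>` values are stand-ins for `SessionContextRegistry` and `Arc<SessionContext>`.

```rust
// Stand-in for the registry pattern above: DashMap shards its locks internally,
// so insert/get/remove need no surrounding RwLock and no async locking.
use dashmap::DashMap;
use std::sync::Arc;

struct Registry {
    running_sessions: DashMap<String, Arc<String>>,
}

impl Registry {
    fn new() -> Self {
        Self {
            running_sessions: DashMap::new(),
        }
    }

    /// Same semantics as HashMap::insert: returns the previous value, if any.
    fn register(&self, id: String, ctx: Arc<String>) -> Option<Arc<String>> {
        self.running_sessions.insert(id, ctx)
    }

    /// `get` returns a guard; clone the Arc out so the shard lock is released immediately.
    fn lookup(&self, id: &str) -> Option<Arc<String>> {
        self.running_sessions.get(id).map(|entry| entry.value().clone())
    }

    /// `remove` yields the (key, value) pair; keep only the value.
    fn unregister(&self, id: &str) -> Option<Arc<String>> {
        self.running_sessions.remove(id).map(|(_, ctx)| ctx)
    }
}

fn main() {
    let registry = Registry::new();
    registry.register("session-1".to_string(), Arc::new("ctx".to_string()));
    assert!(registry.lookup("session-1").is_some());
    assert!(registry.unregister("session-1").is_some());
    assert!(registry.lookup("session-1").is_none());
}
```

One caveat worth keeping in mind: the guard returned by `DashMap::get` holds a shard lock, so the diff (like this sketch) clones the `Arc` out immediately rather than holding the guard across an `.await` point.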
@@ -62,7 +61,9 @@ impl SessionContextRegistry { &self, session_id: &str, ) -> Option> { - let mut sessions = self.running_sessions.write().await; - sessions.remove(session_id) + match self.running_sessions.remove(session_id) { + None => None, + Some(value) => Some(value.1), + } } } diff --git a/ballista/rust/scheduler/src/state/task_manager.rs b/ballista/rust/scheduler/src/state/task_manager.rs index b7d6fd673..cf4708653 100644 --- a/ballista/rust/scheduler/src/state/task_manager.rs +++ b/ballista/rust/scheduler/src/state/task_manager.rs @@ -17,12 +17,12 @@ use crate::scheduler_server::event::QueryStageSchedulerEvent; use crate::scheduler_server::SessionBuilder; -use crate::state::backend::{Keyspace, Lock, StateBackendClient}; +use crate::state::backend::{Keyspace, Operation, StateBackendClient}; use crate::state::execution_graph::{ ExecutionGraph, ExecutionStage, RunningTaskInfo, TaskDescription, }; use crate::state::executor_manager::{ExecutorManager, ExecutorReservation}; -use crate::state::{decode_protobuf, encode_protobuf, with_lock}; +use crate::state::{decode_protobuf, encode_protobuf, with_lock, with_locks}; use ballista_core::config::BallistaConfig; #[cfg(not(test))] use ballista_core::error::BallistaError; @@ -35,6 +35,7 @@ use ballista_core::serde::protobuf::{ use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto; use ballista_core::serde::scheduler::ExecutorMetadata; use ballista_core::serde::{AsExecutionPlan, BallistaCodec}; +use dashmap::DashMap; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; use datafusion_proto::logical_plan::AsLogicalPlan; @@ -45,8 +46,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; - -type ExecutionGraphCache = Arc>>>>; +type ExecutionGraphCache = Arc>>>; // TODO move to configuration file /// Default max failure attempts for task level retry @@ -85,7 +85,7 @@ impl TaskManager session_builder, codec, scheduler_id, - active_job_cache: Arc::new(RwLock::new(HashMap::new())), + active_job_cache: Arc::new(DashMap::new()), } } @@ -111,9 +111,8 @@ impl TaskManager .await?; graph.revive(); - - let mut active_graph_cache = self.active_job_cache.write().await; - active_graph_cache.insert(job_id.to_owned(), Arc::new(RwLock::new(graph))); + self.active_job_cache + .insert(job_id.to_owned(), Arc::new(RwLock::new(graph))); Ok(()) } @@ -262,8 +261,8 @@ impl TaskManager let mut assignments: Vec<(String, TaskDescription)> = vec![]; let mut pending_tasks = 0usize; let mut assign_tasks = 0usize; - let job_cache = self.active_job_cache.read().await; - for (_job_id, graph) in job_cache.iter() { + for pairs in self.active_job_cache.iter() { + let (_job_id, graph) = pairs.pair(); let mut graph = graph.write().await; for reservation in free_reservations.iter().skip(assign_tasks) { if let Some(task) = graph.pop_next_task(&reservation.executor_id)? 
{ @@ -321,7 +320,13 @@ impl TaskManager job_id: &str, failure_reason: String, ) -> Result> { - let lock = self.state.lock(Keyspace::ActiveJobs, "").await?; + let locks = self + .state + .acquire_locks(vec![ + (Keyspace::ActiveJobs, job_id), + (Keyspace::FailedJobs, job_id), + ]) + .await?; if let Some(graph) = self.get_active_execution_graph(job_id).await { let running_tasks = graph.read().await.running_tasks(); info!( @@ -329,12 +334,12 @@ impl TaskManager running_tasks.len(), job_id ); - self.fail_job_state(lock, job_id, failure_reason).await?; + with_locks(locks, self.fail_job_state(job_id, failure_reason)).await?; Ok(running_tasks) } else { // TODO listen the job state update event and fix task cancelling warn!("Fail to find job {} in the cache, unable to cancel tasks for job, fail the job state only.", job_id); - self.fail_job_state(lock, job_id, failure_reason).await?; + with_locks(locks, self.fail_job_state(job_id, failure_reason)).await?; Ok(vec![]) } } @@ -348,44 +353,55 @@ impl TaskManager failure_reason: String, ) -> Result<()> { debug!("Moving job {} from Active or Queue to Failed", job_id); - let lock = self.state.lock(Keyspace::ActiveJobs, "").await?; - self.fail_job_state(lock, job_id, failure_reason).await + let locks = self + .state + .acquire_locks(vec![ + (Keyspace::ActiveJobs, job_id), + (Keyspace::FailedJobs, job_id), + ]) + .await?; + with_locks(locks, self.fail_job_state(job_id, failure_reason)).await } - async fn fail_job_state( - &self, - lock: Box, - job_id: &str, - failure_reason: String, - ) -> Result<()> { - with_lock(lock, self.state.delete(Keyspace::ActiveJobs, job_id)).await?; + async fn fail_job_state(&self, job_id: &str, failure_reason: String) -> Result<()> { + let txn_operations = |value: Vec<u8>| -> Vec<(Operation, Keyspace, String)> { + vec![ + (Operation::Delete, Keyspace::ActiveJobs, job_id.to_string()), + ( + Operation::Put(value), + Keyspace::FailedJobs, + job_id.to_string(), + ), + ] + }; - let value = if let Some(graph) = self.get_active_execution_graph(job_id).await { + let _res = if let Some(graph) = self.get_active_execution_graph(job_id).await { let mut graph = graph.write().await; - for stage_id in graph.running_stages() { - graph.fail_stage(stage_id, failure_reason.clone()); - } + let previous_status = graph.status(); graph.fail_job(failure_reason); - let graph = graph.clone(); - self.encode_execution_graph(graph)? + let value = self.encode_execution_graph(graph.clone())?; + let txn_ops = txn_operations(value); + let result = self.state.apply_txn(txn_ops).await; + if result.is_err() { + // Rollback + graph.update_status(previous_status); + warn!("Rolling back Execution Graph state change since it was not persisted due to a possible connection error.") + }; + result } else { warn!("Fail to find job {} in the cache", job_id); - let status = JobStatus { status: Some(job_status::Status::Failed(FailedJob { error: failure_reason.clone(), })), }; - encode_protobuf(&status)? + let value = encode_protobuf(&status)?; + let txn_ops = txn_operations(value); + self.state.apply_txn(txn_ops).await }; - self.state - .put(Keyspace::FailedJobs, job_id.to_owned(), value) - .await?; - Ok(()) } - pub async fn update_job(&self, job_id: &str) -> Result<()> { debug!("Update job {} in Active", job_id); if let Some(graph) = self.get_active_execution_graph(job_id).await { @@ -408,10 +424,10 @@ impl TaskManager // Collect all the running task need to cancel when there are running stages rolled back.
let mut running_tasks_to_cancel: Vec = vec![]; // Collect graphs we update so we can update them in storage - let mut updated_graphs: HashMap = HashMap::new(); + let updated_graphs: DashMap = DashMap::new(); { - let job_cache = self.active_job_cache.read().await; - for (job_id, graph) in job_cache.iter() { + for pairs in self.active_job_cache.iter() { + let (job_id, graph) = pairs.pair(); let mut graph = graph.write().await; let reset = graph.reset_stages_on_lost_executor(executor_id)?; if !reset.0.is_empty() { @@ -424,14 +440,14 @@ impl TaskManager let lock = self.state.lock(Keyspace::ActiveJobs, "").await?; with_lock(lock, async { // Transactional update graphs - let txn_ops: Vec<(Keyspace, String, Vec)> = updated_graphs + let txn_ops: Vec<(Operation, Keyspace, String)> = updated_graphs .into_iter() .map(|(job_id, graph)| { let value = self.encode_execution_graph(graph)?; - Ok((Keyspace::ActiveJobs, job_id, value)) + Ok((Operation::Put(value), Keyspace::ActiveJobs, job_id)) }) .collect::>>()?; - self.state.put_txn(txn_ops).await?; + self.state.apply_txn(txn_ops).await?; Ok(running_tasks_to_cancel) }) .await @@ -524,8 +540,7 @@ impl TaskManager &self, job_id: &str, ) -> Option>> { - let active_graph_cache = self.active_job_cache.read().await; - active_graph_cache.get(job_id).cloned() + self.active_job_cache.get(job_id).map(|value| value.clone()) } /// Get the `ExecutionGraph` for the given job ID. This will search fist in the `ActiveJobs`