feat(batch): mask failed worker during scheduling
zwang28 committed May 17, 2023
1 parent 9de7f30 commit eb4a53d
Showing 3 changed files with 94 additions and 65 deletions.
24 changes: 20 additions & 4 deletions src/frontend/src/scheduler/distributed/stage.rs
@@ -18,6 +18,7 @@ use std::collections::HashMap;
use std::mem;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;

use anyhow::anyhow;
use arc_swap::ArcSwap;
@@ -826,21 +827,36 @@ impl StageRunner {
plan_fragment: PlanFragment,
worker: Option<WorkerNode>,
) -> SchedulerResult<Fuse<Streaming<TaskInfoResponse>>> {
- let worker_node_addr = worker
- .unwrap_or(self.worker_node_manager.next_random_worker()?)
- .host
- .unwrap();
+ let worker = worker.unwrap_or(self.worker_node_manager.next_random_worker()?);
+ let worker_node_addr = worker.host.unwrap();
let mask_failed_worker = || {
let duration = std::cmp::max(
Duration::from_secs(
self.ctx
.session
.env()
.meta_config()
.max_heartbeat_interval_secs as _,
) / 10,
Duration::from_secs(1),
);
self.worker_node_manager
.manager
.mask_worker_node(worker.id, duration);
};

let compute_client = self
.compute_client_pool
.get_by_addr((&worker_node_addr).into())
.await
.inspect_err(|_| mask_failed_worker())
.map_err(|e| anyhow!(e))?;

let t_id = task_id.task_id;
let stream_status = compute_client
.create_task(task_id, plan_fragment, self.epoch.clone())
.await
.inspect_err(|_| mask_failed_worker())
.map_err(|e| anyhow!(e))?
.fuse();

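The stage.rs change above masks a worker whenever obtaining its compute client or creating a task on it fails. As a minimal, self-contained sketch of the duration rule applied by `mask_failed_worker` (illustrative code, not part of the repository; the 60-second heartbeat interval is an assumed example value), the worker stays hidden for one tenth of `max_heartbeat_interval_secs`, with a one-second floor:

use std::time::Duration;

// Assumed example value; in the commit above this comes from the meta config.
const MAX_HEARTBEAT_INTERVAL_SECS: u64 = 60;

/// Mask duration on failure: a tenth of the heartbeat interval, never below one second.
fn mask_duration(max_heartbeat_interval_secs: u64) -> Duration {
    std::cmp::max(
        Duration::from_secs(max_heartbeat_interval_secs) / 10,
        Duration::from_secs(1),
    )
}

fn main() {
    assert_eq!(mask_duration(MAX_HEARTBEAT_INTERVAL_SECS), Duration::from_secs(6));
    // With a 5-second heartbeat interval, the one-second floor applies instead.
    assert_eq!(mask_duration(5), Duration::from_secs(1));
}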
123 changes: 64 additions & 59 deletions src/frontend/src/scheduler/worker_node_manager.rs
@@ -13,7 +13,8 @@
// limitations under the License.

use std::collections::{BTreeMap, HashMap, HashSet, LinkedList, VecDeque};
- use std::sync::{Arc, RwLock};
+ use std::sync::{Arc, RwLock, RwLockReadGuard};
use std::time::Duration;

use anyhow::anyhow;
use itertools::Itertools;
@@ -31,6 +32,8 @@ use crate::scheduler::{SchedulerError, SchedulerResult};
/// `WorkerNodeManager` manages live worker nodes and table vnode mapping information.
pub struct WorkerNodeManager {
inner: RwLock<WorkerNodeManagerInner>,
/// Temporarily makes a worker invisible to the serving cluster.
worker_node_mask: Arc<RwLock<HashSet<u32>>>,
}

struct WorkerNodeManagerInner {
@@ -57,6 +60,7 @@ impl WorkerNodeManager {
streaming_fragment_vnode_mapping: Default::default(),
serving_fragment_vnode_mapping: Default::default(),
}),
worker_node_mask: Arc::new(Default::default()),
}
}

@@ -67,7 +71,10 @@ impl WorkerNodeManager {
streaming_fragment_vnode_mapping: HashMap::new(),
serving_fragment_vnode_mapping: HashMap::new(),
});
- Self { inner }
+ Self {
+ inner,
+ worker_node_mask: Arc::new(Default::default()),
+ }
}

pub fn list_worker_nodes(&self) -> Vec<WorkerNode> {
@@ -214,6 +221,26 @@ impl WorkerNodeManager {
let pu_mapping = guard.reschedule_serving(fragment_id)?;
Ok(pu_mapping)
}

fn worker_node_mask(&self) -> RwLockReadGuard<'_, HashSet<u32>> {
self.worker_node_mask.read().unwrap()
}

pub fn mask_worker_node(&self, worker_node_id: u32, duration: Duration) {
let mut worker_node_mask = self.worker_node_mask.write().unwrap();
if worker_node_mask.contains(&worker_node_id) {
return;
}
worker_node_mask.insert(worker_node_id);
let worker_node_mask_ref = self.worker_node_mask.clone();
tokio::spawn(async move {
tokio::time::sleep(duration).await;
worker_node_mask_ref
.write()
.unwrap()
.remove(&worker_node_id);
});
}
}
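For reference, a minimal, self-contained sketch of the time-bounded masking pattern that `mask_worker_node` implements above, with simplified types (`WorkerMask` and its methods are illustrative names, not part of the codebase, and a tokio runtime is assumed): a masked worker id is skipped by scheduling until a spawned task unmasks it after the given duration, and re-masking an already masked worker keeps the earlier expiry.

use std::collections::HashSet;
use std::sync::{Arc, RwLock};
use std::time::Duration;

/// Illustrative stand-in for the worker-node mask kept by `WorkerNodeManager`.
#[derive(Clone, Default)]
struct WorkerMask {
    masked: Arc<RwLock<HashSet<u32>>>,
}

impl WorkerMask {
    /// Hides `worker_id` for `duration`; a no-op if it is already masked.
    fn mask(&self, worker_id: u32, duration: Duration) {
        if !self.masked.write().unwrap().insert(worker_id) {
            return;
        }
        let masked = self.masked.clone();
        tokio::spawn(async move {
            tokio::time::sleep(duration).await;
            masked.write().unwrap().remove(&worker_id);
        });
    }

    fn is_masked(&self, worker_id: u32) -> bool {
        self.masked.read().unwrap().contains(&worker_id)
    }
}

#[tokio::main]
async fn main() {
    let mask = WorkerMask::default();
    mask.mask(7, Duration::from_millis(50));
    assert!(mask.is_masked(7));
    tokio::time::sleep(Duration::from_millis(100)).await;
    assert!(!mask.is_masked(7));
}

Spawning an unmask task keeps the read path a plain `HashSet` lookup, at the cost of not extending the expiry when the same worker fails again while still masked.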

impl WorkerNodeManagerInner {
@@ -249,7 +276,8 @@ impl WorkerNodeManagerInner {
)
})?;
let serving_parallelism = std::cmp::min(
- std::cmp::min(serving_pus_total_num, VirtualNode::COUNT),
+ serving_pus_total_num,
+ // Following the streaming vnode mapping is not a must.
streaming_vnode_mapping.iter_unique().count(),
);
assert!(serving_parallelism > 0);
@@ -297,15 +325,16 @@ impl WorkerNodeSelector {
if self.enable_barrier_read {
self.manager.list_streaming_worker_nodes().len()
} else {
- self.manager.list_serving_worker_nodes().len()
+ self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
+ .len()
}
}

pub fn schedule_unit_count(&self) -> usize {
let worker_nodes = if self.enable_barrier_read {
self.manager.list_streaming_worker_nodes()
} else {
- self.manager.list_serving_worker_nodes()
+ self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
};
worker_nodes
.iter()
@@ -327,31 +356,44 @@
))
})
} else {
- self.manager.get_serving_fragment_mapping(fragment_id)
+ let origin = self.manager.get_serving_fragment_mapping(fragment_id)?;
+ if self.manager.worker_node_mask().is_empty() {
+ return Ok(origin);
+ }
+ let new_workers = self.apply_worker_node_mask(self.manager.list_serving_worker_nodes());
+ let masked_mapping =
+ rebalance_serving_vnode(&origin, &new_workers, origin.iter_unique().count());
+ masked_mapping.ok_or_else(|| SchedulerError::EmptyWorkerNodes)
}
}

pub fn next_random_worker(&self) -> SchedulerResult<WorkerNode> {
let worker_nodes = if self.enable_barrier_read {
self.manager.list_streaming_worker_nodes()
} else {
- self.manager.list_serving_worker_nodes()
+ self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
};
worker_nodes
.choose(&mut rand::thread_rng())
.ok_or_else(|| SchedulerError::EmptyWorkerNodes)
.map(|w| (*w).clone())
}

fn apply_worker_node_mask(&self, origin: Vec<WorkerNode>) -> Vec<WorkerNode> {
let mask = self.manager.worker_node_mask();
origin
.into_iter()
.filter(|w| !mask.contains(&w.id))
.collect()
}
}

/// Calculates a new vnode mapping, keeping locality and balance on a best-effort basis.
/// The strategy is similar to `rebalance_actor_vnode` used in the meta node, but is modified to
/// meet the new constraints.
fn rebalance_serving_vnode(
old_pu_mapping: &ParallelUnitMapping,
- old_workers: &[WorkerNode],
- added_workers: &[WorkerNode],
- removed_workers: &[WorkerNode],
+ new_workers: &[WorkerNode],
max_parallelism: usize,
) -> Option<ParallelUnitMapping> {
let get_pu_map = |worker_nodes: &[WorkerNode]| {
@@ -361,11 +403,8 @@
.map(|w| (w.id, w.parallel_units.clone()))
.collect::<BTreeMap<u32, Vec<ParallelUnit>>>()
};
- let removed_pu_map = get_pu_map(removed_workers);
- let mut new_pus: LinkedList<_> = get_pu_map(old_workers)
+ let mut new_pus: LinkedList<_> = get_pu_map(new_workers)
.into_iter()
- .filter(|(w_id, _)| !removed_pu_map.contains_key(w_id))
- .chain(get_pu_map(added_workers))
.map(|(_, pus)| pus.into_iter().sorted_by_key(|p| p.id))
.collect();
let serving_parallelism = std::cmp::min(
@@ -401,7 +440,6 @@
}
let (expected, mut remain) = VirtualNode::COUNT.div_rem(&selected_pu_ids.len());
assert!(expected <= i32::MAX as usize);
- // TODO comments
let mut balances: HashMap<ParallelUnitId, Balance> = HashMap::default();
for pu_id in &selected_pu_ids {
let mut balance = Balance {
@@ -576,9 +614,9 @@ mod tests {
..Default::default()
};
let pu_mapping = ParallelUnitMapping::build(&worker_1.parallel_units);
- assert!(rebalance_serving_vnode(&pu_mapping, &[worker_1.clone()], &[], &[], 0).is_none());
+ assert!(rebalance_serving_vnode(&pu_mapping, &[worker_1.clone()], 0).is_none());
let re_pu_mapping =
- rebalance_serving_vnode(&pu_mapping, &[worker_1.clone()], &[], &[], 10000).unwrap();
+ rebalance_serving_vnode(&pu_mapping, &[worker_1.clone()], 10000).unwrap();
assert_eq!(re_pu_mapping, pu_mapping);
assert_eq!(re_pu_mapping.iter_unique().count(), 1);
let worker_2 = WorkerNode {
@@ -587,14 +625,9 @@
property: Some(serving_property.clone()),
..Default::default()
};
- let re_pu_mapping = rebalance_serving_vnode(
- &re_pu_mapping,
- &[worker_1.clone()],
- &[worker_2.clone()],
- &[],
- 10000,
- )
- .unwrap();
+ let re_pu_mapping =
+ rebalance_serving_vnode(&re_pu_mapping, &[worker_1.clone(), worker_2.clone()], 10000)
+ .unwrap();
assert_ne!(re_pu_mapping, pu_mapping);
assert_eq!(re_pu_mapping.iter_unique().count(), 51);
// 1*256+0 -> 5*51+1
@@ -608,9 +641,7 @@
};
let re_pu_mapping_2 = rebalance_serving_vnode(
&re_pu_mapping,
- &[worker_1.clone(), worker_2.clone()],
- &[worker_3.clone()],
- &[],
+ &[worker_1.clone(), worker_2.clone(), worker_3.clone()],
10000,
)
.unwrap();
@@ -620,8 +651,6 @@
let re_pu_mapping = rebalance_serving_vnode(
&re_pu_mapping_2,
&[worker_1.clone(), worker_2.clone(), worker_3.clone()],
- &[],
- &[],
50,
)
.unwrap();
@@ -631,48 +660,24 @@
let re_pu_mapping_2 = rebalance_serving_vnode(
&re_pu_mapping,
&[worker_1.clone(), worker_2.clone(), worker_3.clone()],
- &[],
- &[],
10000,
)
.unwrap();
assert_eq!(re_pu_mapping_2.iter_unique().count(), 111);
// TODO count_same_vnode_mapping
let re_pu_mapping = rebalance_serving_vnode(
&re_pu_mapping_2,
- &[worker_1.clone(), worker_2.clone(), worker_3.clone()],
- &[],
- &[worker_2.clone()],
+ &[worker_1.clone(), worker_3.clone()],
10000,
)
.unwrap();
// limited by total pu number
assert_eq!(re_pu_mapping.iter_unique().count(), 61);
// TODO count_same_vnode_mapping
- assert!(rebalance_serving_vnode(
- &re_pu_mapping,
- &[worker_1.clone(), worker_3.clone()],
- &[],
- &[worker_1.clone(), worker_3.clone()],
- 10000
- )
- .is_none());
- let re_pu_mapping = rebalance_serving_vnode(
- &re_pu_mapping,
- &[worker_1.clone(), worker_3.clone()],
- &[],
- &[worker_1.clone()],
- 10000,
- )
- .unwrap();
+ assert!(rebalance_serving_vnode(&re_pu_mapping, &[], 10000).is_none());
+ let re_pu_mapping =
+ rebalance_serving_vnode(&re_pu_mapping, &[worker_3.clone()], 10000).unwrap();
assert_eq!(re_pu_mapping.iter_unique().count(), 60);
- assert!(rebalance_serving_vnode(
- &re_pu_mapping,
- &[worker_3.clone()],
- &[],
- &[worker_3.clone()],
- 10000
- )
- .is_none());
+ assert!(rebalance_serving_vnode(&re_pu_mapping, &[], 10000).is_none());
}
}
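As a note on the rebalancing exercised by these tests: once masked or removed workers are dropped, `rebalance_serving_vnode` redistributes the 256 virtual nodes (`VirtualNode::COUNT` here) so that each selected parallel unit owns `256 / parallelism` vnodes and `256 % parallelism` of them own one extra (which units get the extra depends on the old mapping; the sketch below simply assigns it to the first ones). A standalone sketch of just that balance target, with an illustrative helper name that is not project code:

/// Balance target: spread `vnode_count` vnodes over `parallelism` parallel units,
/// each unit getting `vnode_count / parallelism`, and the remainder one extra each.
fn expected_vnodes_per_pu(vnode_count: usize, parallelism: usize) -> Vec<usize> {
    let expected = vnode_count / parallelism;
    let remain = vnode_count % parallelism;
    (0..parallelism)
        .map(|i| if i < remain { expected + 1 } else { expected })
        .collect()
}

fn main() {
    // 256 vnodes over 51 parallel units: fifty units get 5 vnodes and one gets 6,
    // matching the "1*256+0 -> 5*51+1" note in the test above.
    let dist = expected_vnodes_per_pu(256, 51);
    assert_eq!(dist.iter().sum::<usize>(), 256);
    assert_eq!(dist.iter().filter(|&&n| n == 5).count(), 50);
    assert_eq!(dist.iter().filter(|&&n| n == 6).count(), 1);
}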
12 changes: 10 additions & 2 deletions src/frontend/src/session.rs
@@ -31,7 +31,7 @@ use risingwave_common::catalog::DEFAULT_SCHEMA_NAME;
use risingwave_common::catalog::{
DEFAULT_DATABASE_NAME, DEFAULT_SUPER_USER, DEFAULT_SUPER_USER_ID,
};
- use risingwave_common::config::{load_config, BatchConfig};
+ use risingwave_common::config::{load_config, BatchConfig, MetaConfig};
use risingwave_common::error::{ErrorCode, Result, RwError};
use risingwave_common::monitor::process_linux::monitor_process;
use risingwave_common::session_config::ConfigMap;
@@ -111,6 +111,7 @@ pub struct FrontendEnv {
source_metrics: Arc<SourceMetrics>,

batch_config: BatchConfig,
meta_config: MetaConfig,

/// Tracks creating streaming jobs, used to cancel a creating streaming job when a cancel
/// request is received.
@@ -157,6 +158,7 @@ impl FrontendEnv {
sessions_map: Arc::new(Mutex::new(HashMap::new())),
frontend_metrics: Arc::new(FrontendMetrics::for_test()),
batch_config: BatchConfig::default(),
meta_config: MetaConfig::default(),
source_metrics: Arc::new(SourceMetrics::default()),
creating_streaming_job_tracker: Arc::new(creating_streaming_tracker),
}
@@ -173,6 +175,7 @@ impl FrontendEnv {
info!("> version: {} ({})", RW_VERSION, GIT_SHA);

let batch_config = config.batch;
let meta_config = config.meta;

let frontend_address: HostAddr = opts
.advertise_addr
@@ -191,7 +194,7 @@
WorkerType::Frontend,
&frontend_address,
Default::default(),
- &config.meta,
+ &meta_config,
)
.await?;

@@ -322,6 +325,7 @@ impl FrontendEnv {
frontend_metrics,
sessions_map: Arc::new(Mutex::new(HashMap::new())),
batch_config,
meta_config,
source_metrics,
creating_streaming_job_tracker,
},
@@ -386,6 +390,10 @@ impl FrontendEnv {
&self.batch_config
}

pub fn meta_config(&self) -> &MetaConfig {
&self.meta_config
}

pub fn source_metrics(&self) -> Arc<SourceMetrics> {
self.source_metrics.clone()
}
