Add wait_drained to SchedulerServer and Executor #41

Merged: 2 commits, merged Mar 22, 2023
1 change: 0 additions & 1 deletion ballista/core/src/execution_plans/shuffle_reader.rs
@@ -649,7 +649,6 @@ mod tests {
 
     fn get_test_partition_locations(n: usize, path: String) -> Vec<PartitionLocation> {
         (0..n)
-            .into_iter()
             .map(|partition_id| PartitionLocation {
                 map_partition_id: 0,
                 partition_id: PartitionId {
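This hunk, like the matching ones in cluster/event/mod.rs and test_utils.rs below, drops a redundant .into_iter() call: a Range already implements Iterator, so adapters such as map can be chained on it directly. A minimal, self-contained illustration of the same cleanup (the names here are made up for the example, not taken from the Ballista code):

fn main() {
    // Before: (0..3).into_iter().map(...). The into_iter() is a no-op
    // because Range<usize> already implements Iterator.
    let ids: Vec<usize> = (0..3).map(|partition_id| partition_id * 10).collect();
    assert_eq!(ids, vec![0, 10, 20]);
    println!("{ids:?}");
}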
74 changes: 46 additions & 28 deletions ballista/executor/src/executor.rs
@@ -19,10 +19,7 @@
 
 use dashmap::DashMap;
 use std::collections::HashMap;
-use std::future::Future;
-use std::pin::Pin;
 use std::sync::Arc;
-use std::task::{Context, Poll};
 
 use crate::metrics::ExecutorMetricsCollector;
 use ballista_core::error::BallistaError;
@@ -37,23 +34,10 @@ use datafusion::physical_plan::udaf::AggregateUDF;
 use datafusion::physical_plan::udf::ScalarUDF;
 use datafusion::physical_plan::{ExecutionPlan, Partitioning};
 use futures::future::AbortHandle;
+use tokio::sync::watch;
 
 use ballista_core::serde::scheduler::PartitionId;
 
-pub struct TasksDrainedFuture(pub Arc<Executor>);
-
-impl Future for TasksDrainedFuture {
-    type Output = ();
-
-    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
-        if self.0.abort_handles.len() > 0 {
-            Poll::Pending
-        } else {
-            Poll::Ready(())
-        }
-    }
-}
-
 type AbortHandles = Arc<DashMap<(usize, PartitionId), AbortHandle>>;
 
 /// Ballista executor
@@ -82,6 +66,9 @@ pub struct Executor {
 
     /// Handles to abort executing tasks
     abort_handles: AbortHandles,
+
+    drained: Arc<watch::Sender<()>>,
+    check_drained: watch::Receiver<()>,
 }
 
 impl Executor {
@@ -93,17 +80,15 @@ impl Executor {
         metrics_collector: Arc<dyn ExecutorMetricsCollector>,
         concurrent_tasks: usize,
     ) -> Self {
-        Self {
+        Self::with_functions(
             metadata,
-            work_dir: work_dir.to_owned(),
-            // TODO add logic to dynamically load UDF/UDAFs libs from files
-            scalar_functions: HashMap::new(),
-            aggregate_functions: HashMap::new(),
+            work_dir,
             runtime,
             metrics_collector,
             concurrent_tasks,
-            abort_handles: Default::default(),
-        }
+            HashMap::new(),
+            HashMap::new(),
+        )
     }
 
     pub fn with_functions(
@@ -115,6 +100,8 @@ impl Executor {
         scalar_functions: HashMap<String, Arc<ScalarUDF>>,
         aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
     ) -> Self {
+        let (drained, check_drained) = watch::channel(());
+
         Self {
             metadata,
             work_dir: work_dir.to_owned(),
@@ -124,6 +111,8 @@ impl Executor {
             metrics_collector,
             concurrent_tasks,
             abort_handles: Default::default(),
+            drained: Arc::new(drained),
+            check_drained,
         }
     }
 }
@@ -147,9 +136,11 @@ impl Executor {
         self.abort_handles
             .insert((task_id, partition.clone()), abort_handle);
 
-        let partitions = task.await??;
+        let partitions = task.await;
+
+        self.remove_handle(task_id, partition.clone());
 
-        self.abort_handles.remove(&(task_id, partition.clone()));
+        let partitions = partitions??;
 
         self.metrics_collector.record_stage(
             &partition.job_id,
@@ -196,14 +187,14 @@ impl Executor {
         stage_id: usize,
         partition_id: usize,
     ) -> Result<bool, BallistaError> {
-        if let Some((_, handle)) = self.abort_handles.remove(&(
+        if let Some((_, handle)) = self.remove_handle(
            task_id,
            PartitionId {
                job_id,
                stage_id,
                partition_id,
            },
-        )) {
+        ) {
            handle.abort();
            Ok(true)
        } else {
@@ -218,6 +209,33 @@ impl Executor {
     pub fn active_task_count(&self) -> usize {
         self.abort_handles.len()
     }
+
+    pub async fn wait_drained(&self) {
+        let mut check_drained = self.check_drained.clone();
+        loop {
+            if self.active_task_count() == 0 {
+                break;
+            }
+
+            if check_drained.changed().await.is_err() {
+                break;
+            };
+        }
+    }
+
+    fn remove_handle(
+        &self,
+        task_id: usize,
+        partition: PartitionId,
+    ) -> Option<((usize, PartitionId), AbortHandle)> {
+        let removed = self.abort_handles.remove(&(task_id, partition));
+
+        if self.active_task_count() == 0 {
+            self.drained.send_replace(());
+        }
+
+        removed
+    }
 }
 
 #[cfg(test)]
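Taken together, the executor changes pair the DashMap of abort handles with a tokio watch channel: remove_handle signals the channel whenever the map becomes empty, and wait_drained re-checks the task count after every notification, replacing the removed poll-based TasksDrainedFuture. Below is a minimal, self-contained sketch of that pattern, with a hypothetical TaskRegistry standing in for the real Executor (it assumes the tokio and dashmap crates, which the diff already depends on):

use std::sync::Arc;
use std::time::Duration;

use dashmap::DashMap;
use tokio::sync::watch;

struct TaskRegistry {
    tasks: DashMap<usize, ()>,
    // The watch payload is just (); it only signals "re-check the count".
    drained: watch::Sender<()>,
    check_drained: watch::Receiver<()>,
}

impl TaskRegistry {
    fn new() -> Self {
        let (drained, check_drained) = watch::channel(());
        Self {
            tasks: DashMap::new(),
            drained,
            check_drained,
        }
    }

    fn add(&self, id: usize) {
        self.tasks.insert(id, ());
    }

    fn remove(&self, id: usize) {
        self.tasks.remove(&id);
        // Notify waiters whenever the registry may have just drained.
        if self.tasks.is_empty() {
            self.drained.send_replace(());
        }
    }

    async fn wait_drained(&self) {
        let mut check_drained = self.check_drained.clone();
        loop {
            if self.tasks.is_empty() {
                break;
            }
            // Wake on each notification; exit if the sender side is gone.
            if check_drained.changed().await.is_err() {
                break;
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let registry = Arc::new(TaskRegistry::new());
    registry.add(1);

    let r = registry.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(50)).await;
        r.remove(1);
    });

    registry.wait_drained().await;
    println!("all tasks drained");
}

Because the waiter re-checks the map itself after each wake-up, a notification that races with a new task insertion simply sends it back to sleep.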
4 changes: 2 additions & 2 deletions ballista/executor/src/executor_process.rs
@@ -56,7 +56,7 @@ use ballista_core::utils::{
 };
 use ballista_core::BALLISTA_VERSION;
 
-use crate::executor::{Executor, TasksDrainedFuture};
+use crate::executor::Executor;
 use crate::executor_server::TERMINATING;
 use crate::flight_service::BallistaFlightService;
 use crate::metrics::LoggingMetricsCollector;
@@ -301,7 +301,7 @@ pub async fn start_executor_process(opt: ExecutorProcessConfig) -> Result<()> {
         shutdown_noti.subscribe_for_shutdown(),
     )));
 
-    let tasks_drained = TasksDrainedFuture(executor);
+    let tasks_drained = executor.wait_drained();
 
     // Concurrently run the service checking and listen for the `shutdown` signal and wait for the stop request coming.
     // The check_services runs until an error is encountered, so under normal circumstances, this `select!` statement runs
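Here the drained future becomes one branch of the graceful-shutdown select! in start_executor_process. A self-contained sketch of that shape, where a sleep stands in for executor.wait_drained() and the other service futures (this is not the exact select! used in the file):

use std::time::Duration;

use tokio::{select, signal, time};

#[tokio::main]
async fn main() {
    // Stand-in for executor.wait_drained(): resolves once all tasks finish.
    let tasks_drained = time::sleep(Duration::from_millis(100));

    select! {
        // Whichever branch completes first wins; the other is dropped.
        _ = tasks_drained => println!("all tasks drained, shutting down"),
        _ = signal::ctrl_c() => println!("received shutdown signal"),
    }
}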
6 changes: 3 additions & 3 deletions ballista/scheduler/src/cluster/event/mod.rs
@@ -194,7 +194,7 @@ mod test {
            }
        });
 
-        let expected: Vec<i32> = (0..100).into_iter().collect();
+        let expected: Vec<i32> = (0..100).collect();
 
        let results = handle.await.unwrap();
        assert_eq!(results.len(), 3);
@@ -233,7 +233,7 @@ mod test {
 
        // When we reach capacity older events should be dropped so we only see
        // the last 8 events in our subscribers
-        let expected: Vec<i32> = (92..100).into_iter().collect();
+        let expected: Vec<i32> = (92..100).collect();
 
        let results = handle.await.unwrap();
        assert_eq!(results.len(), 3);
@@ -271,7 +271,7 @@ mod test {
            }
        });
 
-        let expected: Vec<i32> = (1..=100).into_iter().collect();
+        let expected: Vec<i32> = (1..=100).collect();
 
        let results = handle.await.unwrap();
        assert_eq!(results.len(), 3);
4 changes: 4 additions & 0 deletions ballista/scheduler/src/scheduler_server/mod.rs
@@ -381,6 +381,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T
     pub fn session_manager(&self) -> SessionManager {
         self.state.session_manager.clone()
     }
+
+    pub async fn wait_drained(&self) {
+        self.state.task_manager.wait_drained().await;
+    }
 }
 
 pub fn timestamp_secs() -> u64 {
42 changes: 34 additions & 8 deletions ballista/scheduler/src/state/task_manager.rs
@@ -48,7 +48,7 @@ use std::ops::Deref;
 use std::sync::Arc;
 use std::time::Duration;
 use std::time::{SystemTime, UNIX_EPOCH};
-use tokio::sync::RwLock;
+use tokio::sync::{watch, RwLock};
 
 use crate::scheduler_server::timestamp_millis;
 use tracing::trace;
@@ -115,6 +115,8 @@ pub struct TaskManager<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
     // Cache for active jobs curated by this scheduler
     active_job_cache: ActiveJobCache,
     launcher: Arc<dyn TaskLauncher>,
+    drained: Arc<watch::Sender<()>>,
+    check_drained: watch::Receiver<()>,
 }
 
 #[derive(Clone)]
@@ -149,13 +151,12 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         codec: BallistaCodec<T, U>,
         scheduler_id: String,
     ) -> Self {
-        Self {
+        Self::with_launcher(
             state,
             codec,
-            scheduler_id: scheduler_id.clone(),
-            active_job_cache: Arc::new(DashMap::new()),
-            launcher: Arc::new(DefaultTaskLauncher::new(scheduler_id)),
-        }
+            scheduler_id.clone(),
+            Arc::new(DefaultTaskLauncher::new(scheduler_id)),
+        )
     }
 
     #[allow(dead_code)]
@@ -165,12 +166,16 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         scheduler_id: String,
         launcher: Arc<dyn TaskLauncher>,
     ) -> Self {
+        let (drained, check_drained) = watch::channel(());
+
         Self {
             state,
             codec,
             scheduler_id,
             active_job_cache: Arc::new(DashMap::new()),
             launcher,
+            drained: Arc::new(drained),
+            check_drained,
         }
     }
 
@@ -690,9 +695,16 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         &self,
         job_id: &str,
     ) -> Option<Arc<RwLock<ExecutionGraph>>> {
-        self.active_job_cache
+        let removed = self
+            .active_job_cache
             .remove(job_id)
-            .map(|value| value.1.execution_graph)
+            .map(|value| value.1.execution_graph);
+
+        if self.get_active_job_count() == 0 {
+            self.drained.send_replace(());
+        }
+
+        removed
     }
 
     /// Generate a new random Job ID
@@ -721,6 +733,20 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
            }
        });
    }
+
+    pub async fn wait_drained(&self) {
+        let mut check_drained = self.check_drained.clone();
+
+        loop {
+            if self.get_active_job_count() == 0 {
+                break;
+            }
+
+            if check_drained.changed().await.is_err() {
+                break;
+            };
+        }
+    }
 }
 
 pub struct JobOverview {
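The scheduler-side wait_drained has the same shape as the executor's: the watch payload is just (), so a notification only tells the waiter to re-check the active-job count, and the loop also exits if the sender half is dropped. A self-contained sketch of that loop in isolation, with an AtomicUsize standing in for get_active_job_count() (illustrative names, not the TaskManager API):

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use tokio::sync::watch;

// Waits until the shared count reaches zero, waking on each notification.
async fn wait_until_zero(count: Arc<AtomicUsize>, mut rx: watch::Receiver<()>) {
    loop {
        if count.load(Ordering::SeqCst) == 0 {
            break;
        }
        // A notification only means "something changed"; re-check the count.
        if rx.changed().await.is_err() {
            break;
        }
    }
}

#[tokio::main]
async fn main() {
    let count = Arc::new(AtomicUsize::new(2));
    let (tx, rx) = watch::channel(());

    let waiter = tokio::spawn(wait_until_zero(count.clone(), rx));

    // Finish the two "jobs"; the waiter exits once the count reaches zero.
    for _ in 0..2 {
        count.fetch_sub(1, Ordering::SeqCst);
        tx.send_replace(());
    }

    waiter.await.unwrap();
    println!("scheduler drained");
}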
2 changes: 0 additions & 2 deletions ballista/scheduler/src/test_utils.rs
@@ -279,7 +279,6 @@ pub fn default_task_runner() -> impl TaskRunner {
    };
 
    let partitions: Vec<ShuffleWritePartition> = (0..partitions)
-        .into_iter()
        .map(|i| ShuffleWritePartition {
            partition_id: i as u64,
            path: String::default(),
@@ -410,7 +409,6 @@ impl SchedulerTest {
         let runner = runner.unwrap_or_else(|| Arc::new(default_task_runner()));
 
         let executors: HashMap<String, VirtualExecutor> = (0..num_executors)
-            .into_iter()
            .map(|i| {
                let id = format!("virtual-executor-{i}");
                let executor = VirtualExecutor {