Skip to content

Commit

Permalink
feat(katana-tasks): task manager (#2318)
Browse files Browse the repository at this point in the history
  • Loading branch information
kariy authored Aug 20, 2024
1 parent 4f76fd7 commit bb7d7df
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 9 deletions.
44 changes: 36 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ strum_macros = "0.25"
tempfile = "3.9.0"
test-log = "0.2.11"
thiserror = "1.0.32"
tokio = { version = "1.32.0", features = [ "full" ] }
tokio = { version = "1.39.2", features = [ "full" ] }
toml = "0.8"
tower = "0.4.13"
tower-http = "0.4.4"
Expand Down
3 changes: 3 additions & 0 deletions crates/katana/tasks/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ futures.workspace = true
rayon.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-metrics = "0.3.1"
tokio-util = { version = "0.7.11", features = [ "rt" ] }
tracing.workspace = true
7 changes: 7 additions & 0 deletions crates/katana/tasks/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
#![cfg_attr(not(test), warn(unused_crate_dependencies))]

mod manager;
mod task;

use std::any::Any;
use std::future::Future;
use std::panic::{self, AssertUnwindSafe};
Expand All @@ -6,7 +11,9 @@ use std::sync::Arc;
use std::task::Poll;

use futures::channel::oneshot;
pub use manager::*;
use rayon::ThreadPoolBuilder;
pub use task::*;
use tokio::runtime::Handle;
use tokio::task::JoinHandle;

Expand Down
175 changes: 175 additions & 0 deletions crates/katana/tasks/src/manager.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
use std::future::Future;

use tokio::runtime::Handle;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;

use crate::task::{TaskBuilder, TaskResult};

pub type TaskHandle<T> = JoinHandle<TaskResult<T>>;

/// Usage for this task manager is mainly to spawn tasks that can be cancelled, and captures
/// panicked tasks (which in the context of the task manager - a critical task) for graceful
/// shutdown.
#[derive(Debug, Clone)]
pub struct TaskManager {
/// A handle to the Tokio runtime.
handle: Handle,
/// Keep track of currently running tasks.
tracker: TaskTracker,
/// Used to cancel all running tasks.
///
/// This is passed to all the tasks spawned by the manager.
pub(crate) on_cancel: CancellationToken,
}

impl TaskManager {
/// Create a new [`TaskManager`] from the given Tokio runtime handle.
pub fn new(handle: Handle) -> Self {
Self { handle, tracker: TaskTracker::new(), on_cancel: CancellationToken::new() }
}

pub fn current() -> Self {
Self::new(Handle::current())
}

pub fn spawn<F>(&self, fut: F) -> TaskHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.spawn_inner(fut)
}

/// Wait until all spawned tasks are completed.
pub async fn wait(&self) {
// need to close the tracker first before waiting
let _ = self.tracker.close();
self.tracker.wait().await;
// reopen the tracker for spawning future tasks
let _ = self.tracker.reopen();
}

/// Consumes the manager and wait until all tasks are finished, either due to completion or
/// cancellation.
pub async fn wait_shutdown(self) {
// need to close the tracker first before waiting
let _ = self.tracker.close();
let _ = self.on_cancel.cancelled().await;
self.tracker.wait().await;
}

/// Return the handle to the Tokio runtime that the manager is associated with.
pub fn handle(&self) -> &Handle {
&self.handle
}

/// Returns a new [`TaskBuilder`] for building a task to be spawned on this manager.
pub fn build_task(&self) -> TaskBuilder<'_> {
TaskBuilder::new(self)
}

fn spawn_inner<F>(&self, task: F) -> TaskHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let task = self.make_cancellable(task);
let task = self.tracker.track_future(task);
self.handle.spawn(task)
}

fn make_cancellable<F>(&self, fut: F) -> impl Future<Output = TaskResult<F::Output>>
where
F: Future,
{
let ct = self.on_cancel.clone();
async move {
tokio::select! {
_ = ct.cancelled() => {
TaskResult::Cancelled
},
res = fut => {
TaskResult::Completed(res)
},
}
}
}
}

impl Drop for TaskManager {
fn drop(&mut self) {
self.on_cancel.cancel();
}
}

#[cfg(test)]
mod tests {
use futures::future;
use tokio::time::{self, Duration};

use super::*;

#[tokio::test]
async fn normal_tasks() {
let manager = TaskManager::current();

manager.spawn(time::sleep(Duration::from_secs(1)));
manager.spawn(time::sleep(Duration::from_secs(1)));
manager.spawn(time::sleep(Duration::from_secs(1)));

// 3 tasks should be spawned on the manager
assert_eq!(manager.tracker.len(), 3);

// wait until all task spawned to the manager have been completed
manager.wait().await;

assert!(
!manager.on_cancel.is_cancelled(),
"cancellation signal shouldn't be sent on normal task completion"
)
}

#[tokio::test]
async fn task_with_graceful_shutdown() {
let manager = TaskManager::current();

// mock long running normal task and a task with graceful shutdown
manager.build_task().spawn(async {
loop {
time::sleep(Duration::from_secs(1)).await
}
});

manager.build_task().spawn(async {
loop {
time::sleep(Duration::from_secs(1)).await
}
});

// assert that 2 tasks should've been spawned
assert_eq!(manager.tracker.len(), 2);

// Spawn a task with graceful shuwdown that finish immediately.
// The long running task should be cancelled due to the graceful shutdown.
manager.build_task().graceful_shutdown().spawn(future::ready(()));

// wait until all task spawned to the manager have been completed
manager.wait_shutdown().await;
}

#[tokio::test]
async fn critical_task_implicit_graceful_shutdown() {
let manager = TaskManager::current();
manager.build_task().critical().spawn(future::ready(()));
manager.wait_shutdown().await;
}

#[tokio::test]
async fn critical_task_graceful_shudown_on_panicked() {
let manager = TaskManager::current();
manager.build_task().critical().spawn(async { panic!("panicking") });
manager.wait_shutdown().await;
}
}
Loading

0 comments on commit bb7d7df

Please sign in to comment.