Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add maximum simultaneous tasks support to TaskContainer #464

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ci/conda/recipes/libmrc/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ requirements:
- gtest =1.14
- libhwloc =2.9.2
- librmm {{ rapids_version }}
- nlohmann_json =3.9
- nlohmann_json =3.11
- pybind11-abi # See: https://conda-forge.org/docs/maintainer/knowledge_base.html#pybind11-abi-constraints
- pybind11-stubgen =0.10
- python {{ python }}
Expand Down Expand Up @@ -90,12 +90,12 @@ outputs:
- libgrpc =1.59
- libhwloc =2.9.2
- librmm {{ rapids_version }}
- nlohmann_json =3.9
- nlohmann_json =3.11
- ucx =1.15
run:
# Manually add any packages necessary for run that do not have run_exports. Keep sorted!
- cuda-version {{ cuda_version }}.*
- nlohmann_json =3.9
- nlohmann_json =3.11
- ucx =1.15
- cuda-cudart
- boost-cpp =1.84
Expand Down
2 changes: 2 additions & 0 deletions ci/iwyu/mappings.imp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

# boost
{ "include": ["@<boost/fiber/future/detail/.*>", "private", "<boost/fiber/future/future.hpp>", "public"] },
{ "include": ["@<boost/algorithm/string/detail/.*>", "private", "<boost/algorithm/string.hpp>", "public"] },

# cuda
{ "include": ["<cuda_runtime_api.h>", "private", "<cuda_runtime.h>", "public"] },
Expand All @@ -33,6 +34,7 @@
{ "symbol": ["@grpc::.*", "private", "<grpcpp/grpcpp.h>", "public"] },

# nlohmann json
{ "include": ["<nlohmann/json_fwd.hpp>", "public", "<nlohmann/json.hpp>", "public"] },
{ "include": ["<nlohmann/detail/iterators/iter_impl.hpp>", "private", "<nlohmann/json.hpp>", "public"] },
{ "include": ["<nlohmann/detail/iterators/iteration_proxy.hpp>", "private", "<nlohmann/json.hpp>", "public"] },
{ "include": ["<nlohmann/detail/json_ref.hpp>", "private", "<nlohmann/json.hpp>", "public"] },
Expand Down
2 changes: 1 addition & 1 deletion conda/environments/all_cuda-121_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ dependencies:
- libxml2=2.11.6
- llvmdev=16
- ninja=1.11
- nlohmann_json=3.9
- nlohmann_json=3.11
- numactl-libs-cos7-x86_64
- numpy=1.24
- pkg-config=0.29
Expand Down
2 changes: 1 addition & 1 deletion conda/environments/ci_cuda-121_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies:
- librmm=24.02
- libxml2=2.11.6
- ninja=1.11
- nlohmann_json=3.9
- nlohmann_json=3.11
- numactl-libs-cos7-x86_64
- pkg-config=0.29
- pre-commit
Expand Down
12 changes: 11 additions & 1 deletion cpp/mrc/include/mrc/coroutines/task_container.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include <memory>
#include <mutex>
#include <optional>
#include <queue>
#include <vector>

namespace mrc::coroutines {
Expand All @@ -60,7 +61,7 @@ class TaskContainer
* @param e Tasks started in the container are scheduled onto this executor. For tasks created
* from a coro::io_scheduler, this would usually be that coro::io_scheduler instance.
*/
TaskContainer(std::shared_ptr<Scheduler> e);
TaskContainer(std::shared_ptr<Scheduler> e, std::size_t max_simultaneous_tasks = -1);

TaskContainer(const TaskContainer&) = delete;
TaskContainer(TaskContainer&&) = delete;
Expand Down Expand Up @@ -138,6 +139,11 @@ class TaskContainer
*/
auto gc_internal() -> std::size_t;

/**
* Starts the next taks in the queue.
*/
void start_next_task();

/**
* Encapsulate the users tasks in a cleanup task which marks itself for deletion upon
* completion. Simply co_await the users task until its completed and then mark the given
Expand Down Expand Up @@ -166,6 +172,10 @@ class TaskContainer
std::shared_ptr<Scheduler> m_scheduler_lifetime{nullptr};
/// This is used internally since io_scheduler cannot pass itself in as a shared_ptr.
Scheduler* m_scheduler{nullptr};
/// tasks to be processed in order of start
std::queue<decltype(m_tasks.end())> m_next_tasks;
/// maximum number of tasks to be run simultaneously
int32_t m_max_simultaneous_tasks;
cwharris marked this conversation as resolved.
Show resolved Hide resolved

friend Scheduler;
};
Expand Down
7 changes: 7 additions & 0 deletions cpp/mrc/include/mrc/coroutines/test_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ class TestScheduler : public Scheduler
*/
mrc::coroutines::Task<> yield_until(std::chrono::time_point<std::chrono::steady_clock> time) override;

/**
* Returns the time according to the scheduler. Time may be progressed by resume_next, resume_for, and resume_until.
*
* @return the current time according to the scheduler.
*/
std::chrono::time_point<std::chrono::steady_clock> time();

/**
* Immediately resumes the next-in-queue coroutine handle.
*
Expand Down
64 changes: 46 additions & 18 deletions cpp/mrc/src/public/coroutines/task_container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@

namespace mrc::coroutines {

TaskContainer::TaskContainer(std::shared_ptr<Scheduler> e) :
TaskContainer::TaskContainer(std::shared_ptr<Scheduler> e, std::size_t max_simultaneous_tasks) :
m_scheduler_lifetime(std::move(e)),
m_scheduler(m_scheduler_lifetime.get())
m_scheduler(m_scheduler_lifetime.get()),
m_max_simultaneous_tasks(max_simultaneous_tasks)
{
if (m_scheduler_lifetime == nullptr)
{
Expand All @@ -53,20 +54,26 @@ auto TaskContainer::start(Task<void>&& user_task, GarbageCollectPolicy cleanup)
{
m_size.fetch_add(1, std::memory_order::relaxed);
cwharris marked this conversation as resolved.
Show resolved Hide resolved

std::scoped_lock lk{m_mutex};

if (cleanup == GarbageCollectPolicy::yes)
{
gc_internal();
}
std::scoped_lock lk{m_mutex};

// Store the task inside a cleanup task for self deletion.
auto pos = m_tasks.emplace(m_tasks.end(), std::nullopt);
auto task = make_cleanup_task(std::move(user_task), pos);
*pos = std::move(task);
if (cleanup == GarbageCollectPolicy::yes)
{
gc_internal();
}

// Start executing from the cleanup task to schedule the user's task onto the thread pool.
pos->value().resume();
// Store the task inside a cleanup task for self deletion.
auto pos = m_tasks.emplace(m_tasks.end(), std::nullopt);
auto task = make_cleanup_task(std::move(user_task), pos);
*pos = std::move(task);

m_next_tasks.push(pos);
}

if (m_max_simultaneous_tasks <= 0 or m_size <= m_max_simultaneous_tasks)
cwharris marked this conversation as resolved.
Show resolved Hide resolved
{
start_next_task();
}
}

auto TaskContainer::garbage_collect() -> std::size_t
Expand Down Expand Up @@ -133,6 +140,19 @@ auto TaskContainer::gc_internal() -> std::size_t
return deleted;
}

void TaskContainer::start_next_task()
{
auto pos = [this]() {
std::scoped_lock lk{m_mutex};
auto pos = m_next_tasks.front();
m_next_tasks.pop();
return pos;
}();

// Start executing from the cleanup task to schedule the user's task onto the thread pool.
pos->value().resume();
cwharris marked this conversation as resolved.
Show resolved Hide resolved
cwharris marked this conversation as resolved.
Show resolved Hide resolved
}

auto TaskContainer::make_cleanup_task(Task<void> user_task, task_position_t pos) -> Task<void>
{
// Immediately move the task onto the executor.
Expand All @@ -155,11 +175,19 @@ auto TaskContainer::make_cleanup_task(Task<void> user_task, task_position_t pos)
LOG(ERROR) << "coro::task_container user_task had unhandle exception, not derived from std::exception.\n";
}

std::scoped_lock lk{m_mutex};
m_tasks_to_delete.push_back(pos);
// This has to be done within scope lock to make sure this coroutine task completes before the
// task container object destructs -- if it was waiting on .empty() to become true.
m_size.fetch_sub(1, std::memory_order::relaxed);
{
std::scoped_lock lk{m_mutex};
m_tasks_to_delete.push_back(pos);
// This has to be done within scope lock to make sure this coroutine task completes before the
// task container object destructs -- if it was waiting on .empty() to become true.
m_size.fetch_sub(1, std::memory_order::relaxed);
}

if (not m_next_tasks.empty())
{
start_next_task();
}
cwharris marked this conversation as resolved.
Show resolved Hide resolved

co_return;
}

Expand Down
12 changes: 12 additions & 0 deletions cpp/mrc/src/public/coroutines/test_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "mrc/coroutines/test_scheduler.hpp"

#include <chrono>
#include <compare>

namespace mrc::coroutines {
Expand Down Expand Up @@ -56,8 +57,14 @@ mrc::coroutines::Task<> TestScheduler::yield_until(std::chrono::time_point<std::
co_return co_await TestScheduler::Operation{this, time};
}

std::chrono::time_point<std::chrono::steady_clock> TestScheduler::time() {
return m_time;
}

bool TestScheduler::resume_next()
{
using namespace std::chrono_literals;

if (m_queue.empty())
{
return false;
Expand All @@ -69,6 +76,11 @@ bool TestScheduler::resume_next()

m_time = handle.second;

if (not m_queue.empty())
{
m_time = m_queue.top().second;
}

handle.first.resume();

return true;
Expand Down
48 changes: 48 additions & 0 deletions cpp/mrc/tests/coroutines/test_task_container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,57 @@
* limitations under the License.
*/

#include "mrc/coroutines/scheduler.hpp"
#include "mrc/coroutines/sync_wait.hpp"
#include "mrc/coroutines/task.hpp"
#include "mrc/coroutines/task_container.hpp"
#include "mrc/coroutines/test_scheduler.hpp"
#include "mrc/coroutines/when_all.hpp"

#include <gtest/gtest.h>

#include <memory>
#include <thread>

class TestCoroTaskContainer : public ::testing::Test
{};

TEST_F(TestCoroTaskContainer, LifeCycle) {}

TEST_F(TestCoroTaskContainer, MaxSimultaneousTasks)
{
using namespace std::chrono_literals;

auto on = std::make_shared<mrc::coroutines::TestScheduler>();
auto task_container = mrc::coroutines::TaskContainer(on, 2);

auto start_time = on->time();

std::vector<std::chrono::time_point<std::chrono::steady_clock>> execution_times;

auto delay = [](std::shared_ptr<mrc::coroutines::TestScheduler> on,
std::vector<std::chrono::time_point<std::chrono::steady_clock>>& execution_times)
-> mrc::coroutines::Task<> {
co_await on->yield_for(100ms);
execution_times.emplace_back(on->time());
};

for (auto i = 0; i < 4; i++)
{
task_container.start(delay(on, execution_times));
}

auto task = task_container.garbage_collect_and_yield_until_empty();

task.resume();

while (on->resume_next()) {}

mrc::coroutines::sync_wait(task);

ASSERT_EQ(execution_times.size(), 4);
ASSERT_EQ(execution_times[0], start_time + 100ms);
ASSERT_EQ(execution_times[1], start_time + 100ms);
ASSERT_EQ(execution_times[2], start_time + 200ms);
ASSERT_EQ(execution_times[3], start_time + 200ms);
}
2 changes: 1 addition & 1 deletion dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ dependencies:
- librmm=24.02
- libxml2=2.11.6 # 2.12 has a bug preventing round-trip serialization in hwloc
- ninja=1.11
- nlohmann_json=3.9
- nlohmann_json=3.11
- numactl-libs-cos7-x86_64
- pkg-config=0.29
- pybind11-stubgen=0.10
Expand Down
1 change: 1 addition & 0 deletions python/mrc/_pymrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ add_library(pymrc
src/utilities/acquire_gil.cpp
src/utilities/deserializers.cpp
src/utilities/function_wrappers.cpp
src/utilities/json_values.cpp
src/utilities/object_cache.cpp
src/utilities/object_wrappers.cpp
src/utilities/serializers.cpp
Expand Down
12 changes: 1 addition & 11 deletions python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,6 @@ class AsyncioRunnable : public AsyncSink<InputT>,
std::shared_ptr<mrc::coroutines::Scheduler> on) = 0;

std::stop_source m_stop_source;

/**
* @brief A semaphore used to control the number of outstanding operations. Acquire one before
* beginning a task, and release it when finished.
*/
std::counting_semaphore<8> m_task_tickets{8};
};

template <typename InputT, typename OutputT>
Expand Down Expand Up @@ -282,14 +276,12 @@ void AsyncioRunnable<InputT, OutputT>::run(mrc::runnable::Context& ctx)
template <typename InputT, typename OutputT>
coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<mrc::coroutines::Scheduler> scheduler)
{
coroutines::TaskContainer outstanding_tasks(scheduler);
coroutines::TaskContainer outstanding_tasks(scheduler, 8);

ExceptionCatcher catcher{};

while (not m_stop_source.stop_requested() and not catcher.has_exception())
{
m_task_tickets.acquire();

InputT data;

auto read_status = co_await this->read_async(data);
Expand Down Expand Up @@ -335,8 +327,6 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::process_one(InputT value,
{
catcher.push_exception(std::current_exception());
}

m_task_tickets.release();
}

template <typename InputT, typename OutputT>
Expand Down
19 changes: 17 additions & 2 deletions python/mrc/_pymrc/include/pymrc/types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -21,9 +21,12 @@

#include "mrc/segment/object.hpp"

#include <nlohmann/json_fwd.hpp>
#include <rxcpp/rx.hpp>

#include <functional>
#include <functional> // for function
#include <map>
#include <string>

namespace mrc::pymrc {

Expand All @@ -37,4 +40,16 @@ using PyNode = mrc::segment::ObjectProperties;
using PyObjectOperateFn = std::function<PyObjectObservable(PyObjectObservable source)>;
// NOLINTEND(readability-identifier-naming)

using python_map_t = std::map<std::string, pybind11::object>;

/**
* @brief Unserializable handler function type, invoked by `cast_from_pyobject` when an object cannot be serialized to
* JSON. Implementations should return a valid json object, or throw an exception if the object cannot be serialized.
* @param source : pybind11 object
* @param path : string json path to object
* @return nlohmann::json.
*/
using unserializable_handler_fn_t =
std::function<nlohmann::json(const pybind11::object& /* source*/, const std::string& /* path */)>;

} // namespace mrc::pymrc
Loading
Loading