Skip to content

Commit

Permalink
ReduceTask choose device stream in round-robin fashion to avoid dynam…
Browse files Browse the repository at this point in the history
…ic decisions and need for locking
  • Loading branch information
evaleev committed Sep 28, 2023
1 parent dfa3f76 commit 089dcc6
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
16 changes: 13 additions & 3 deletions src/TiledArray/external/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -798,14 +798,24 @@ class Env {
};

namespace detail {
// Accessor for a thread-local pointer to an std::optional<Stream>.
// Inside a madness device task the pointer is set to that task's own optional
// stream, which madness_task_stream_opt_accessor() then hands out; it is set
// back to nullptr after the task callable has finished. Starts out as nullptr.
// The pointer is returned by reference so callers can both read and reassign
// it.
inline std::optional<Stream>*& madness_task_stream_opt_ptr_accessor() {
static thread_local std::optional<Stream>* stream_opt_ptr = nullptr;
return stream_opt_ptr;
}

// Accessor for a per-thread optional stream. The thread-local object is
// initialized from device::Env::instance()->stream(0) the first time the
// calling thread executes this function, and is returned by reference, so
// modifications made through the result persist for that thread.
inline std::optional<Stream>& tls_stream_opt_accessor() {
static thread_local std::optional<Stream> stream_opt =
device::Env::instance()->stream(0);
return stream_opt;
}

/// Returns the optional device stream the calling thread should use.
///
/// If the thread-local pointer from madness_task_stream_opt_ptr_accessor() is
/// non-null (i.e. a device task has published its own optional stream), the
/// pointed-to optional is returned. Otherwise the calling thread's own
/// optional from tls_stream_opt_accessor() is returned, so this function no
/// longer requires being called from within a device task.
inline std::optional<Stream>& madness_task_stream_opt_accessor() {
  if (madness_task_stream_opt_ptr_accessor() != nullptr)  // in a device task?
    return *madness_task_stream_opt_ptr_accessor();
  else
    return tls_stream_opt_accessor();
}
} // namespace detail

Expand Down Expand Up @@ -867,7 +877,7 @@ inline void cancel_madness_task_sync() {
/// associated with Range \p range
template <typename Range>
device::Stream stream_for(const Range& range) {
auto stream_opt = madness_task_current_stream();
const auto stream_opt = madness_task_current_stream();
if (!stream_opt) {
auto stream_ord =
range.offset() % device::Env::instance()->num_streams_total();
Expand Down
27 changes: 19 additions & 8 deletions src/TiledArray/reduce_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <TiledArray/external/device.h>
#include <TiledArray/tensor/type_traits.h>
#include <TiledArray/util/time.h>
// Counter from which ReduceTaskImpl draws its task id when constructed with
// the default task_id of -1; that id is then reduced modulo the total number
// of device streams to choose the task's stream. The atomic post-increment
// hands each constructor a distinct value without a lock.
inline std::atomic<std::int64_t> global_reduce_task_counter(0);
#endif

namespace TiledArray {
Expand Down Expand Up @@ -596,15 +597,15 @@ class ReduceTask {
/// Reduce two reduction arguments
void reduce_object_object(const ReduceObject* object1,
const ReduceObject* object2) {
// Construct an empty result object
auto result = std::make_shared<result_type>(op_());

#ifdef TILEDARRAY_HAS_DEVICE
TA_ASSERT(device::detail::madness_task_stream_opt_ptr_accessor() ==
nullptr);
device::detail::madness_task_stream_opt_ptr_accessor() = &stream_;
#endif

// Construct an empty result object
auto result = std::make_shared<result_type>(op_());

// Reduce the two arguments
op_(*result, object1->arg());
op_(*result, object2->arg());
Expand Down Expand Up @@ -692,9 +693,9 @@ class ReduceTask {
Future<result_type> result_; ///< The result of the reduction task
madness::Spinlock lock_; ///< Task lock
madness::CallbackInterface* callback_; ///< The completion callback
int task_id_; ///< Task id
std::int64_t task_id_; ///< Task id
#ifdef TILEDARRAY_HAS_DEVICE
std::optional<device::Stream> stream_;
std::optional<device::Stream> stream_; // round-robined by task_id
#endif

public:
Expand All @@ -706,7 +707,7 @@ class ReduceTask {
/// has completed
/// \param task_id the task id (for debugging)
ReduceTaskImpl(World& world, opT op, madness::CallbackInterface* callback,
int task_id = -1)
std::int64_t task_id = -1)
: madness::TaskInterface(1, TaskAttributes::hipri()),
world_(world),
op_(op),
Expand All @@ -715,7 +716,16 @@ class ReduceTask {
result_(),
lock_(),
callback_(callback),
task_id_(task_id) {}
task_id_(task_id) {
#ifdef TILEDARRAY_HAS_DEVICE
if (task_id_ == -1) {
task_id_ = global_reduce_task_counter++;
const std::size_t stream_ord =
task_id_ % device::Env::instance()->num_streams_total();
stream_ = device::Env::instance()->stream(stream_ord);
}
#endif
}

virtual ~ReduceTaskImpl() {}

Expand Down Expand Up @@ -780,7 +790,8 @@ class ReduceTask {
/// this task is complete
/// \param task_id the task id (for debugging)
ReduceTask(World& world, const opT& op = opT(),
           madness::CallbackInterface* callback = nullptr,
           // 64-bit to match ReduceTaskImpl's task_id parameter; the default
           // of -1 is forwarded unchanged to ReduceTaskImpl
           std::int64_t task_id = -1)
    : pimpl_(new ReduceTaskImpl(world, op, callback, task_id)), count_(0ul) {}

/// Move constructor
Expand Down

0 comments on commit 089dcc6

Please sign in to comment.