diff --git a/src/TiledArray/external/device.h b/src/TiledArray/external/device.h index 133bb11c56..dcf286c443 100644 --- a/src/TiledArray/external/device.h +++ b/src/TiledArray/external/device.h @@ -798,14 +798,24 @@ class Env { }; namespace detail { +// in a madness device task point to its local optional stream to use by +// madness_task_stream_opt; set to nullptr after task callable finished inline std::optional*& madness_task_stream_opt_ptr_accessor() { static thread_local std::optional* stream_opt_ptr = nullptr; return stream_opt_ptr; } +inline std::optional& tls_stream_opt_accessor() { + static thread_local std::optional stream_opt = + device::Env::instance()->stream(0); + return stream_opt; +} + inline std::optional& madness_task_stream_opt_accessor() { - TA_ASSERT(madness_task_stream_opt_ptr_accessor() != nullptr); - return *madness_task_stream_opt_ptr_accessor(); + if (madness_task_stream_opt_ptr_accessor() != nullptr) // in a device task? + return *madness_task_stream_opt_ptr_accessor(); + else + return tls_stream_opt_accessor(); } } // namespace detail @@ -867,7 +877,7 @@ inline void cancel_madness_task_sync() { /// associated with Range \p range template device::Stream stream_for(const Range& range) { - auto stream_opt = madness_task_current_stream(); + const auto stream_opt = madness_task_current_stream(); if (!stream_opt) { auto stream_ord = range.offset() % device::Env::instance()->num_streams_total(); diff --git a/src/TiledArray/reduce_task.h b/src/TiledArray/reduce_task.h index 2a5813ff10..7d8924b0c3 100644 --- a/src/TiledArray/reduce_task.h +++ b/src/TiledArray/reduce_task.h @@ -29,6 +29,7 @@ #include #include #include +inline std::atomic global_reduce_task_counter(0); #endif namespace TiledArray { @@ -596,15 +597,15 @@ class ReduceTask { /// Reduce two reduction arguments void reduce_object_object(const ReduceObject* object1, const ReduceObject* object2) { - // Construct an empty result object - auto result = std::make_shared(op_()); - #ifdef TILEDARRAY_HAS_DEVICE TA_ASSERT(device::detail::madness_task_stream_opt_ptr_accessor() == nullptr); device::detail::madness_task_stream_opt_ptr_accessor() = &stream_; #endif + // Construct an empty result object + auto result = std::make_shared(op_()); + // Reduce the two arguments op_(*result, object1->arg()); op_(*result, object2->arg()); @@ -692,9 +693,9 @@ class ReduceTask { Future result_; ///< The result of the reduction task madness::Spinlock lock_; ///< Task lock madness::CallbackInterface* callback_; ///< The completion callback - int task_id_; ///< Task id + std::int64_t task_id_; ///< Task id #ifdef TILEDARRAY_HAS_DEVICE - std::optional stream_; + std::optional stream_; // round-robined by task_id #endif public: @@ -706,7 +707,7 @@ class ReduceTask { /// has completed /// \param task_id the task id (for debugging) ReduceTaskImpl(World& world, opT op, madness::CallbackInterface* callback, - int task_id = -1) + std::int64_t task_id = -1) : madness::TaskInterface(1, TaskAttributes::hipri()), world_(world), op_(op), @@ -715,7 +716,16 @@ class ReduceTask { result_(), lock_(), callback_(callback), - task_id_(task_id) {} + task_id_(task_id) { +#ifdef TILEDARRAY_HAS_DEVICE + if (task_id_ == -1) { + task_id_ = global_reduce_task_counter++; + const std::size_t stream_ord = + task_id_ % device::Env::instance()->num_streams_total(); + stream_ = device::Env::instance()->stream(stream_ord); + } +#endif + } virtual ~ReduceTaskImpl() {} @@ -780,7 +790,8 @@ class ReduceTask { /// this task is complete /// \param task_id the task id (for debugging) ReduceTask(World& world, const opT& op = opT(), - madness::CallbackInterface* callback = nullptr, int task_id = -1) + madness::CallbackInterface* callback = nullptr, + std::int64_t task_id = -1) : pimpl_(new ReduceTaskImpl(world, op, callback, task_id)), count_(0ul) {} /// Move constructor