From 99e83dfb5035cf52712a2ead8c069da9aca2e1ec Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Mon, 7 Jun 2021 12:30:18 +0900 Subject: [PATCH 1/8] Add async GPU depency Engine Signed-off-by: Serge Panev --- include/mxnet/base.h | 6 +- include/mxnet/c_api.h | 2 +- include/mxnet/engine.h | 121 +++++++++++- include/mxnet/storage.h | 19 +- src/c_api/c_api.cc | 13 +- src/common/object_pool.h | 4 +- src/engine/engine.cc | 21 ++- src/engine/naive_engine.cc | 22 ++- src/engine/stream_manager.h | 24 ++- src/engine/threaded_engine.cc | 229 ++++++++++++++++++++++- src/engine/threaded_engine.h | 43 +++-- src/engine/threaded_engine_perdevice.cc | 58 +++++- src/engine/threaded_engine_pooled.cc | 29 ++- src/imperative/imperative_utils.h | 49 +++-- src/io/batchify.cc | 2 +- src/io/dataset.cc | 4 +- src/kvstore/comm.h | 23 ++- src/kvstore/gradient_compression.cc | 4 - src/kvstore/kvstore_dist.h | 21 ++- src/kvstore/kvstore_dist_server.h | 10 +- src/kvstore/kvstore_local.h | 7 +- src/kvstore/p3store_dist.h | 13 +- src/ndarray/ndarray.cc | 196 ++++++++++++------- src/operator/custom/ndarray_op.cc | 10 +- src/operator/operator_util.cc | 78 +++----- src/resource.cc | 116 ++++++------ src/storage/gpu_device_storage.h | 8 + src/storage/pooled_storage_manager.h | 47 +++-- src/storage/storage.cc | 2 +- tests/cpp/engine/threaded_engine_test.cc | 38 +++- tests/python/gpu/test_gluon_gpu.py | 86 --------- tests/python/gpu/test_operator_gpu.py | 73 -------- 32 files changed, 920 insertions(+), 458 deletions(-) diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 12b083c67576..aecd7a300394 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -347,10 +347,10 @@ struct RunContext { * \brief the auxiliary stream of the device, can be nullptr or Stream* in GPU mode */ void *aux_stream; - /*! - * \brief indicator of whether this execution is run in bulk mode + /*! + * \brief pointer to the cuda event pool used by the dependecy engine */ - bool is_bulk; + void *event_pool = nullptr; /*! * \brief get mshadow stream from Context * \return the mshadow stream diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 6e668a48883c..0aff74772c47 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -110,7 +110,7 @@ typedef const void *EngineFnPropertyHandle; typedef void *EngineVarHandle; /*! \brief Engine asynchronous operation */ -typedef void (*EngineAsyncFunc)(void*, void*, void*); +typedef void (*EngineAsyncFunc)(void*, void*, void*, void*); /*! \brief Engine synchronous operation */ typedef void (*EngineSyncFunc)(void*, void*); /*! \brief Callback to free the param for EngineAsyncFunc/EngineSyncFunc */ diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index 1a9582edd518..e37c0e646e29 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -29,6 +29,7 @@ #include #include #endif +#include #include #include "./base.h" @@ -39,6 +40,72 @@ class Engine; /*! \brief namespace of engine internal types. */ namespace engine { +#if MXNET_USE_CUDA +/* \brief The class wrapping CUDA event with timing disabled. 
*/ +class CUDAEvent final { + public: + explicit CUDAEvent(Context const& ctx); + +CUDAEvent(CUDAEvent&& other) + : event_(other.event_), dev_id_(other.dev_id_) { + other.event_ = nullptr; + } + + CUDAEvent(const CUDAEvent& other) = delete; + void operator=(const CUDAEvent& other) = delete; + + ~CUDAEvent(); + + inline std::weak_ptr GetEvent() noexcept { + return event_; + } + private: + std::shared_ptr event_; + int dev_id_; +}; + +class CUDAEventPool final { + public: + explicit CUDAEventPool(Context const& ctx) : counter_(0) { + for (size_t i = 0; i < kPoolSize; ++i) { + events_.emplace_back(ctx); + } + } + + inline std::weak_ptr GetEvent(size_t i) noexcept { + return events_.at(i).GetEvent(); + } + + inline std::pair, uint64_t> GetNextEvent() noexcept { + int c = counter_++; + return {events_.at((c) % kPoolSize).GetEvent(), c}; + } + + inline uint64_t GetCounterValue() noexcept { + return counter_.load(); + } + private: + static constexpr size_t kPoolSize = 64; + std::vector events_; + std::atomic counter_; +}; + +/*! \brief full event info for the sync object.*/ +struct EventInfo { + std::weak_ptr event; + cudaStream_t stream; + uint64_t pool_index; +}; +/*! \brief struct containing cuda events and variables needed for the dependencies.*/ +struct SyncObject { + // vector can carry multiple reader events + std::vector reader_events; + // vector should carry only 1 writer event + std::vector writer_event; + std::mutex mutex; +}; +#endif + /*! \brief base class of engine variables.*/ struct Var { virtual size_t version() { @@ -57,6 +124,12 @@ struct Var { * is modified, the version number is incremented by 1. */ size_t version_{0}; +#if MXNET_USE_CUDA + /*! + * \brief struct containing cuda events and variables needed for the dependencies. + */ + SyncObject sync_object; +#endif }; // struct Var /*! \brief Internal representation of operator. */ @@ -65,6 +138,29 @@ struct Opr; typedef Var* VarHandle; /*! \brief Operator pointer type, usually hold by user.*/ typedef Opr* OprHandle; +/*! + * \brief OnStart callback to the engine, + * called by AsyncFn before the action + */ +class CallbackOnStart { + public: + // use implicit copy and assign + /*! \brief involve the callback */ + inline void operator()(const dmlc::Error* error = nullptr) const { + if (callback_ != nullptr) + (*callback_)(engine_, param_, error); + } + + private: + /*! \brief engine can see content of callback */ + friend class ::mxnet::Engine; + /*! \brief the real callback */ + void (*callback_)(Engine *, void *, const dmlc::Error *); + /*! \brief the engine class passed to callback */ + Engine* engine_; + /*! \brief the parameter set on callback */ + void* param_; +}; /*! * \brief OnComplete Callback to the engine, * called by AsyncFn when action completes @@ -115,12 +211,14 @@ enum class FnProperty { */ class MXNET_API Engine { public: + /*! \brief on start*/ + typedef engine::CallbackOnStart CallbackOnStart; /*! \brief callback on complete*/ typedef engine::CallbackOnComplete CallbackOnComplete; /*! \brief Synchronous operation to pass to engine. */ typedef std::function SyncFn; /*! \brief Asynchronous operation to pass to engine. */ - typedef std::function AsyncFn; + typedef std::function AsyncFn; /*! \brief Variable pointer */ typedef engine::VarHandle VarHandle; /*! \brief Operator pointer */ @@ -247,7 +345,7 @@ class MXNET_API Engine { * * \return A shared pointer to Engine singleton. */ - static std::shared_ptr _GetSharedRef(); + static const std::shared_ptr &_GetSharedRef(); /*! 
 * \brief Push a synchronous operation to the engine.
std::shared_ptr<void> shared_func_param(func_param, deleter);
MXNET_USE_CUDNN != 0, dev_id); aux_streams_[dev_id] = new GPUAuxStream(streams_[dev_id]); } - exec_fun(RunContext{exec_ctx, streams_[dev_id], aux_streams_[dev_id], false}, callback); + exec_fun(RunContext{exec_ctx, streams_[dev_id], aux_streams_[dev_id]}, + on_start, + callback); #else LOG(FATAL) << "GPU is not enabled"; #endif } else { - exec_fun(RunContext{exec_ctx, &cpu_stream_, nullptr, false}, callback); + exec_fun(RunContext{exec_ctx, &cpu_stream_, nullptr}, on_start, callback); } future.wait(); // increment mutable var version @@ -209,7 +214,10 @@ class NaiveEngine final : public Engine { void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override { NaiveVar* naive_var = NaiveVar::CastFromBase(var); this->PushAsync( - [delete_fn, naive_var](RunContext ctx, CallbackOnComplete on_complete) mutable { + [delete_fn, naive_var](RunContext ctx, + CallbackOnStart on_start, + CallbackOnComplete on_complete) mutable { + on_start(); delete_fn(ctx); NaiveVar::Delete(naive_var); on_complete(); @@ -233,6 +241,10 @@ class NaiveEngine final : public Engine { } private: + // onstart + static void OnStart(Engine *engine, void *param, + const dmlc::Error* error) { + } // callback to oncomplete static void OnComplete(Engine* engine, void* param, const dmlc::Error* error) { static_cast*>(param)->set_value(); diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h index 87df70ede310..f0cfb19fa7de 100644 --- a/src/engine/stream_manager.h +++ b/src/engine/stream_manager.h @@ -25,7 +25,9 @@ #include #include #include +#include #include +#include "./engine_impl.h" #include "../common/cuda/utils.h" namespace mxnet { @@ -55,6 +57,7 @@ class StreamManager { std::array, kNumGpus> gpu_aux_streams_; std::array*, kNumGpus> gpu_io_streams_; std::array gpu_cnt_; + std::array, kNumGpus> event_pools_; #endif // MXNET_USE_CUDA DISALLOW_COPY_AND_ASSIGN(StreamManager); }; // class StreamManager @@ -64,11 +67,12 @@ RunContext StreamManager::GetRunContext(Context const& ctx) RunContext ret; switch (ctx.dev_mask()) { case cpu::kDevMask: - ret = RunContext{ctx, nullptr, nullptr, false}; + ret = RunContext{ctx, nullptr, nullptr}; break; case gpu::kDevMask: { #if MXNET_USE_CUDA std::size_t use_counter; + CUDAEventPool* event_pool; { std::lock_guard lock{mutex_}; auto&& counter = gpu_cnt_.at(ctx.dev_id); @@ -84,13 +88,17 @@ RunContext StreamManager::GetRunContext(Context const& ctx) } counter = 0; } + if (event_pools_.at(ctx.dev_id) == nullptr) { + event_pools_[ctx.dev_id] = std::make_unique(ctx); + } + event_pool = event_pools_.at(ctx.dev_id).get(); use_counter = counter; counter = (counter + 1) % kStreams; } ret = RunContext{ctx, gpu_streams_.at(ctx.dev_id).at(use_counter), gpu_aux_streams_.at(ctx.dev_id).at(use_counter), - false}; + event_pool}; break; #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; @@ -107,18 +115,23 @@ RunContext StreamManager::GetIORunContext(Context const& ctx RunContext ret; switch (ctx.dev_mask()) { case cpu::kDevMask: - ret = RunContext{ctx, nullptr, nullptr, false}; + ret = RunContext{ctx, nullptr, nullptr}; break; case gpu::kDevMask: { #if MXNET_USE_CUDA + CUDAEventPool* event_pool; { std::lock_guard lock{mutex_}; if (gpu_io_streams_.at(ctx.dev_id) == nullptr) { mxnet::common::cuda::DeviceStore device_store(ctx.dev_id); gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream(false, false, ctx.dev_id); } + if (event_pools_.at(ctx.dev_id) == nullptr) { + event_pools_[ctx.dev_id] = std::make_unique(ctx); + } + event_pool = event_pools_.at(ctx.dev_id).get(); } - ret = 
RunContext{ctx, gpu_io_streams_.at(ctx.dev_id), nullptr, false}; + ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id), nullptr, event_pool}; break; #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; @@ -147,6 +160,9 @@ void StreamManager::Finalize() { #if MXNET_USE_CUDA for (std::size_t i = 0; i < kNumGpus; ++i) { if (gpu_cnt_.at(i) != -1) { + if (event_pools_.at(i) != nullptr) { + event_pools_[i].reset(); + } for (auto&& primary_stream : gpu_streams_.at(i)) { // Catch exception for CUDA driver shutdown MSHADOW_CATCH_ERROR(mshadow::DeleteStream(primary_stream)); diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 58af6df66770..deba125ed8e2 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -272,7 +272,10 @@ void ThreadedEngine::DeleteOperator(OprHandle op) { deps.insert(deps.end(), threaded_opr->const_vars.begin(), threaded_opr->const_vars.end()); deps.insert(deps.end(), threaded_opr->mutable_vars.begin(), threaded_opr->mutable_vars.end()); this->PushAsync( - [threaded_opr](RunContext, CallbackOnComplete on_complete) { + [threaded_opr](RunContext, + CallbackOnStart on_start, + CallbackOnComplete on_complete) { + on_start(); ThreadedOpr::Delete(threaded_opr); on_complete(); }, @@ -349,7 +352,10 @@ void ThreadedEngine::PushSync(SyncFn exec_fn, const char* opr_name) { if (!bulk_size() || prop != FnProperty::kNormal || priority) { this->PushAsync( - [exec_fn](RunContext ctx, CallbackOnComplete on_complete) { + [exec_fn](RunContext ctx, + CallbackOnStart on_start, + CallbackOnComplete on_complete) { + on_start(); exec_fn(ctx); on_complete(); }, @@ -371,9 +377,12 @@ void ThreadedEngine::PushSync(SyncFn exec_fn, void ThreadedEngine::DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) { ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var); this->PushAsync( - [delete_fn, threaded_var](RunContext ctx, CallbackOnComplete on_complete) { + [delete_fn, threaded_var](RunContext ctx, + CallbackOnStart on_start, + CallbackOnComplete on_complete) { // Mark variable as orphan, // so during `ThreadedEngine::OnComplete` it could be recycled. 
+ on_start(); threaded_var->SetToDelete(); delete_fn(ctx); on_complete(); @@ -399,7 +408,10 @@ void ThreadedEngine::WaitForVar(VarHandle var) { } std::atomic done{false}; this->PushAsync( - [this, &done](RunContext, CallbackOnComplete on_complete) { + [this, &done](RunContext, + CallbackOnStart on_start, + CallbackOnComplete on_complete) { + on_start(); if (engine_info_) { LOG(INFO) << "Sync is executed"; } @@ -480,6 +492,14 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { } }); if (to_delete) { +#if MXNET_USE_CUDA + auto& sync_obj = i->sync_object; + { + std::lock_guard l(sync_obj.mutex); + sync_obj.reader_events.clear(); + sync_obj.writer_event.clear(); + } +#endif ThreadedVar::Delete(i); } } @@ -533,5 +553,206 @@ void ThreadedEngine::OnCompleteStatic(Engine* engine, void* opr_block_, const dm OprBlock::Delete(opr_block); } +void ThreadedEngine::OnStartStatic(Engine *engine, void *opr_block, + const dmlc::Error* error) { + // no-op +} + +#if MXNET_USE_CUDA +static inline void AddEventHelper( + std::unordered_map* events_per_stream, + const EventInfo& cuda_event) { + auto event_stream = cuda_event.stream; + if (events_per_stream->count(event_stream) > 0) { + if ((*events_per_stream)[event_stream].pool_index < cuda_event.pool_index) { + (*events_per_stream)[event_stream] = cuda_event; + } + } else { + (*events_per_stream).emplace(event_stream, cuda_event); + } +} + +void ThreadedEngine::OnStartCPU(Engine *engine, void *opr_block, + const dmlc::Error* error) { + static bool use_new_dep_engine = dmlc::GetEnv("MXNET_ASYNC_GPU_ENGINE", false); + if (!use_new_dep_engine) { + return; + } + ThreadedOpr *threaded_opr = static_cast(opr_block)->opr; + std::unordered_map event_per_stream; + for (auto* read_var : threaded_opr->const_vars) { + auto &sync_obj = read_var->sync_object; + std::lock_guard l(sync_obj.mutex); + auto &reader_events = sync_obj.reader_events; + // check for expired events and delete them + reader_events.erase(std::remove_if(reader_events.begin(), reader_events.end(), + [&](const EventInfo e_i) { + return e_i.event.expired(); + }), reader_events.end()); + for (auto& cuda_event : reader_events) { + AddEventHelper(&event_per_stream, cuda_event); + } + if (!sync_obj.writer_event.empty()) { + if (sync_obj.writer_event[0].event.expired()) { + sync_obj.writer_event.clear(); + } else { + AddEventHelper(&event_per_stream, sync_obj.writer_event[0]); + } + } + } + + for (auto* write_var : threaded_opr->mutable_vars) { + auto &sync_obj = write_var->sync_object; + std::lock_guard l(sync_obj.mutex); + auto &reader_events = sync_obj.reader_events; + // check for expired events and delete them + reader_events.erase(std::remove_if(reader_events.begin(), reader_events.end(), + [&](const EventInfo e_i) { + return e_i.event.expired(); + }), reader_events.end()); + for (auto& cuda_event : reader_events) { + AddEventHelper(&event_per_stream, cuda_event); + } + if (!sync_obj.writer_event.empty()) { + if (sync_obj.writer_event[0].event.expired()) { + sync_obj.writer_event.clear(); + } else { + AddEventHelper(&event_per_stream, sync_obj.writer_event[0]); + } + } + } + for (auto event : event_per_stream) { + auto ev = event.second.event.lock(); + MSHADOW_CUDA_CALL(cudaEventSynchronize(*ev)); + } +} + +void ThreadedEngine::OnStartGPU(Engine *engine, void *sync_info, + const dmlc::Error* error) { + static bool use_new_dep_engine = dmlc::GetEnv("MXNET_ASYNC_GPU_ENGINE", false); + if (!use_new_dep_engine) { + return; + } + auto *info = reinterpret_cast(sync_info); + CHECK(info->stream 
!= nullptr); + auto *worker_stream = reinterpret_cast *>(info->stream); + ThreadedOpr *threaded_opr = static_cast(info->opr_block)->opr; + std::unordered_map event_per_stream; + for (auto* read_var : threaded_opr->const_vars) { + auto &sync_obj = read_var->sync_object; + std::lock_guard l(sync_obj.mutex); + auto &reader_events = sync_obj.reader_events; + // check for expired events and delete them + reader_events.erase(std::remove_if(reader_events.begin(), reader_events.end(), + [&](const EventInfo e_i) { + return e_i.event.expired(); + }), reader_events.end()); + for (auto& writer : sync_obj.writer_event) { + if (writer.event.expired()) { + sync_obj.writer_event.clear(); + break; + } + if (writer.stream != worker_stream->stream_) { + // if there is already a reader on the same stream as us, + // it already synced with that writer and we can rely on + // the ongoing sync + bool found = false; + for (const auto& reader : reader_events) { + if (reader.stream == worker_stream->stream_) { + found = true; + break; + } + } + if (!found) { + AddEventHelper(&event_per_stream, + writer); + } + } + } + } + for (auto* write_var : threaded_opr->mutable_vars) { + auto &sync_obj = write_var->sync_object; + std::lock_guard l(sync_obj.mutex); + // check for expired events and delete them + auto &reader_events = sync_obj.reader_events; + reader_events.erase(std::remove_if(reader_events.begin(), reader_events.end(), + [&](const EventInfo e_i) { + return e_i.event.expired(); + }), reader_events.end()); + // if there are some readers, we wait for them + for (auto& cuda_event : reader_events) { + if (worker_stream->stream_ != cuda_event.stream) { + AddEventHelper(&event_per_stream, cuda_event); + } + } + if (!sync_obj.writer_event.empty()) { + if (sync_obj.writer_event[0].event.expired()) { + sync_obj.writer_event.clear(); + } else { + if (worker_stream->stream_ != sync_obj.writer_event[0].stream) { + AddEventHelper(&event_per_stream, sync_obj.writer_event[0]); + } + } + } + } + for (auto event : event_per_stream) { + auto ev = event.second.event.lock(); + MSHADOW_CUDA_CALL(cudaStreamWaitEvent(worker_stream->stream_, *ev, 0)); + } +} + +void ThreadedEngine::OnCompleteGPU(Engine *engine, void *sync_info, + const dmlc::Error* error) { + auto *info = reinterpret_cast(sync_info); + CHECK(info->stream != nullptr); + + auto *worker_stream = reinterpret_cast *>(info->stream); + static bool use_new_dep_engine = dmlc::GetEnv("MXNET_ASYNC_GPU_ENGINE", false); + + if (!use_new_dep_engine) { + worker_stream->Wait(); + ThreadedEngine::OnCompleteStatic(engine, info->opr_block, error); + GPUWorkerSyncInfo::Delete(info); + return; + } + + ThreadedOpr *threaded_opr = static_cast(info->opr_block)->opr; + auto* event_pool = static_cast(info->event_pool); + auto[event, event_pool_idx] = event_pool->GetNextEvent(); + auto ev = event.lock(); + MSHADOW_CUDA_CALL(cudaEventRecord(*ev, worker_stream->stream_)); + for (auto* read_var : threaded_opr->const_vars) { + auto &sync_obj = read_var->sync_object; + std::lock_guard l(sync_obj.mutex); + // If some reader event is already recorded on the same stream, + // we want to replace ourselves by it + int i; + for (i = 0; i < sync_obj.reader_events.size(); ++i) { + auto stream = sync_obj.reader_events[i].stream; + if (stream == worker_stream->stream_) { + sync_obj.reader_events[i].event = event; + sync_obj.reader_events[i].pool_index = event_pool_idx; + break; + } + } + if (i == sync_obj.reader_events.size()) { + sync_obj.reader_events.push_back({event, worker_stream->stream_, 
event_pool_idx}); + } + } + + for (auto* write_var : threaded_opr->mutable_vars) { + auto &sync_obj = write_var->sync_object; + std::lock_guard l(sync_obj.mutex); + sync_obj.reader_events.clear(); + sync_obj.writer_event.clear(); + sync_obj.writer_event.push_back({event, worker_stream->stream_, event_pool_idx}); + } + + ThreadedEngine::OnCompleteStatic(engine, info->opr_block, error); + GPUWorkerSyncInfo::Delete(info); +} +#endif + + } // namespace engine } // namespace mxnet diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index bd3f34ca5500..1848e42a4e00 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -353,7 +353,8 @@ class ThreadedEngine : public Engine { * \param run_ctx runtime context used to execute the function. * \param opr_block the opr_block to be executed and deleted. */ - void ExecuteOprBlock(RunContext run_ctx, OprBlock* opr_block) { + void ExecuteOprBlock(RunContext run_ctx, OprBlock* opr_block, + CallbackOnStart on_start, CallbackOnComplete callback) { ThreadedOpr* threaded_opr = opr_block->opr; if (opr_block->profiling && threaded_opr->opr_name.size()) { std::unique_ptr attrs; @@ -365,7 +366,6 @@ class ThreadedEngine : public Engine { new profiler::ProfileOperator(threaded_opr->opr_name.c_str(), attrs.release())); opr_block->opr_profile->startForDevice(ctx.dev_type, ctx.dev_id); } - CallbackOnComplete callback = this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block); const bool debug_info = (engine_info_ && debug_push_opr_ == opr_block); if (debug_info) { LOG(INFO) << "ExecuteOprBlock " << opr_block << "shutdown_phase=" << shutdown_phase_; @@ -381,11 +381,13 @@ class ThreadedEngine : public Engine { if ((!(threaded_opr->opr_exception && *threaded_opr->opr_exception) || threaded_opr->prop == FnProperty::kNoSkip) || threaded_opr->wait) { - threaded_opr->fn(run_ctx, callback); + threaded_opr->fn(run_ctx, on_start, callback); } else { + on_start(); callback(); } } catch (const std::exception& e) { + on_start(); threaded_opr->opr_exception = std::make_shared(std::current_exception()); callback(); @@ -408,6 +410,7 @@ class ThreadedEngine : public Engine { } } } else { + on_start(); callback(); } } @@ -429,6 +432,27 @@ class ThreadedEngine : public Engine { return bulk_size; } + protected: + static void OnStartStatic(Engine *engine, void *opr_block, + const dmlc::Error* error); + static void OnCompleteStatic(Engine *engine, void *threaded_opr, + const dmlc::Error* error); +#if MXNET_USE_CUDA + static void OnStartCPU(Engine *engine, void *opr_block, + const dmlc::Error* error); + static void OnStartGPU(Engine *engine, void *sync_info, + const dmlc::Error* error); + static void OnCompleteGPU(Engine *engine, void *sync_info, + const dmlc::Error* error); + struct GPUWorkerSyncInfo : public common::ObjectPoolAllocatable { + void *opr_block{nullptr}; + void *stream{nullptr}; + void *event_pool{nullptr}; + }; + + std::shared_ptr > objpool_gpu_sync_ref_; +#endif + private: /*! \brief structure for holding bulk execution status */ struct BulkStatus { @@ -491,7 +515,6 @@ class ThreadedEngine : public Engine { } } - static void OnCompleteStatic(Engine* engine, void* threaded_opr, const dmlc::Error* error); /*! 
 * \brief find the exception in global_exception_refs and add it if missing
@@ -265,9 +290,20 @@ class ThreadedEnginePerDevice : public ThreadedEngine { aux_stream = new GPUAuxStream(stream); } } while (false); + // register stream + streams_.push_back(stream); + CUDAEventPool* event_pool; + auto event_pool_it = cuda_event_pool_per_worker_.find(ctx.dev_id); + if (event_pool_it != cuda_event_pool_per_worker_.end()) { + event_pool = event_pool_it->second.get(); + } else { + auto res = cuda_event_pool_per_worker_.emplace(ctx.dev_id, + std::make_unique(ctx)); + event_pool = res.first->second.get(); + } // execute task OprBlock* opr_block; - RunContext run_ctx{ctx, stream, aux_stream, false}; + RunContext run_ctx{ctx, stream, aux_stream}; auto* task_queue = &(block->task_queue); // Don't eat up omp threads for GPU jobs. They're probably best used elsewhere, @@ -288,6 +324,13 @@ class ThreadedEnginePerDevice : public ThreadedEngine { #if MXNET_USE_NVTX common::cuda::nvtx::gpuRangeStop(); #endif + auto* info = ThreadedEngine::GPUWorkerSyncInfo::New(); + info->opr_block = opr_block; + info->stream = stream; + info->event_pool = event_pool; + CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartGPU, info); + CallbackOnComplete callback = this->CreateCallback(ThreadedEngine::OnCompleteGPU, info); + this->ExecuteOprBlock(run_ctx, opr_block, on_start, callback); } #else ready_event->signal(); @@ -303,7 +346,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { const std::shared_ptr& ready_event) { this->is_worker_ = true; auto* task_queue = &(block->task_queue); - RunContext run_ctx{ctx, nullptr, nullptr, false}; + RunContext run_ctx{ctx, nullptr, nullptr}; // execute task OprBlock* opr_block; @@ -313,7 +356,14 @@ class ThreadedEnginePerDevice : public ThreadedEngine { OpenMP::Get()->on_start_worker_thread(true); while (task_queue->Pop(&opr_block)) { - this->ExecuteOprBlock(run_ctx, opr_block); +#if MXNET_USE_CUDA + CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartCPU, opr_block); +#else + CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartStatic, opr_block); +#endif + CallbackOnComplete callback = this->CreateCallback(ThreadedEngine::OnCompleteStatic, + opr_block); + this->ExecuteOprBlock(run_ctx, opr_block, on_start, callback); } } diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc index c0ca03991218..2d9183667ee4 100644 --- a/src/engine/threaded_engine_pooled.cc +++ b/src/engine/threaded_engine_pooled.cc @@ -47,6 +47,10 @@ namespace engine { class ThreadedEnginePooled : public ThreadedEngine { public: ThreadedEnginePooled() { +#if MXNET_USE_CUDA + // Make sure that the pool is not destroyed before the engine + objpool_gpu_sync_ref_ = common::ObjectPool::_GetSharedRef(); +#endif this->Start(); } @@ -55,13 +59,13 @@ class ThreadedEnginePooled : public ThreadedEngine { } void StopNoWait() { - streams_->Finalize(); task_queue_->SignalForKill(); io_task_queue_->SignalForKill(); task_queue_ = nullptr; io_task_queue_ = nullptr; thread_pool_ = nullptr; io_thread_pool_ = nullptr; + streams_->Finalize(); streams_ = nullptr; } @@ -152,7 +156,28 @@ class ThreadedEnginePooled : public ThreadedEngine { opr_block->opr->prop == FnProperty::kCopyToGPU); auto&& rctx = is_copy ? 
streams_->GetIORunContext(opr_block->ctx) : streams_->GetRunContext(opr_block->ctx); - this->ExecuteOprBlock(rctx, opr_block); +#if MXNET_USE_CUDA + CallbackOnStart on_start; + CallbackOnComplete callback; + if (opr_block->ctx.dev_mask() == Context::kCPU) { + on_start = this->CreateOnStart(ThreadedEngine::OnStartCPU, opr_block); + callback = this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block); + } else { + CHECK_EQ(opr_block->ctx.dev_mask(), Context::kGPU); + auto stream = rctx.get_stream(); + auto event_pool = static_cast(rctx.event_pool); + auto* info = ThreadedEngine::GPUWorkerSyncInfo::New(); + info->opr_block = opr_block; + info->stream = stream; + info->event_pool = event_pool; + on_start = this->CreateOnStart(ThreadedEngine::OnStartGPU, info); + callback = this->CreateCallback(ThreadedEngine::OnCompleteGPU, info); + } +#else // MXNET_USE_CUDA + CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartStatic, opr_block); + CallbackOnComplete callback = this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block); +#endif // MXNET_USE_CUDA + this->ExecuteOprBlock(rctx, opr_block, on_start, callback); } /*! * \brief Push the operation to the queue. diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 5944b0a2ff22..8f71417851c5 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -695,14 +695,11 @@ inline void PushFCompute(const FCompute& fn, fn(attrs, opctx, input_blobs, tmp_req, output_blobs); // post-fcompute fallback, cast to original storage type CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu); - if (is_gpu && !rctx.is_bulk) { - rctx.get_stream()->Wait(); - } DerefInputOutputRelease(inputs, outputs); }; if (CheckIfSkipEngine(attrs)) { // execute without engine - run(RunContext{ctx, nullptr, nullptr, false}); + run(RunContext{ctx, nullptr, nullptr}); } else { Engine::Get()->PushSync( run, ctx, read_vars, write_vars, FnProperty::kNormal, 0, op->name.c_str()); @@ -733,12 +730,9 @@ inline void PushFComputeEx(const FComputeEx& fn, INVALIDATE_OUTPUTS_COND(!cross_device_copy, outputsA, req); CREATE_DEFAULT_INPUTS(!cross_device_copy, attrs, CreateDefaultInputs(&inputsA)); fn(attrs, opctx, inputsA, req, outputsA); - if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync && !rctx.is_bulk) { - rctx.get_stream()->Wait(); - } }; if (cross_device_copy || CheckIfSkipEngine(attrs)) { - run(RunContext{ctx, nullptr, nullptr, false}); + run(RunContext{ctx, nullptr, nullptr}); } else { CHECK(exec_type == ExecType::kSync); Engine::Get()->PushSync( @@ -769,7 +763,9 @@ inline void PushOperator(const OpStatePtr& state, auto fcompute_ex = common::GetFCompute(op, "FStatefulComputeEx", ctx); if (fcompute_ex != nullptr && dispatch_mode == DispatchMode::kFComputeEx) { - const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) { + const auto& run = [=](RunContext rctx, + engine::CallbackOnStart on_start, + engine::CallbackOnComplete on_complete) { OpContext opctx{need_grad, is_train, rctx, on_complete, requested}; REDEFINE_INPUTS_OUTPUTS(inputs, outputs, inputsA, outputsA); INVALIDATE_OUTPUTS_COND( @@ -777,20 +773,19 @@ inline void PushOperator(const OpStatePtr& state, CREATE_DEFAULT_INPUTS(exec_type != ExecType::kCrossDeviceCopy && op->name != "_CachedOp", attrs, CreateDefaultInputs(&inputsA)); + on_start(); fcompute_ex(state, opctx, inputsA, req, outputsA); - if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync && - rctx.get_stream() && 
!rctx.is_bulk) { - rctx.get_stream()->Wait(); - } }; // For operators with subgraphs, we need to invoke them in the main thread // instead of the threaded engine. if (exec_type == ExecType::kSubgraphExec || CheckIfSkipEngine(attrs)) { - RunContext rctx{ctx, nullptr, nullptr, false}; - run(rctx, engine::CallbackOnComplete()); + RunContext rctx{ctx, nullptr, nullptr}; + run(rctx, engine::CallbackOnStart(), engine::CallbackOnComplete()); } else if (exec_type == ExecType::kSync) { - Engine::Get()->PushSync([=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); }, + Engine::Get()->PushSync([=](RunContext rctx) { run(rctx, + engine::CallbackOnStart(), + engine::CallbackOnComplete()); }, ctx, read_vars, write_vars, @@ -808,7 +803,9 @@ inline void PushOperator(const OpStatePtr& state, << "One of FStatefulCompute and FStatefulComputeEx must be registered " << "for stateful operator " << op->name; - const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) { + const auto& run = [=](RunContext rctx, + engine::CallbackOnStart on_start, + engine::CallbackOnComplete on_complete) { OpContext opctx{need_grad, is_train, rctx, on_complete, requested}; std::vector input_blobs, output_blobs; @@ -840,17 +837,16 @@ inline void PushOperator(const OpStatePtr& state, fcompute(state, opctx, input_blobs, tmp_req, output_blobs); // post-fcompute fallback, cast to original storage type, if necessary CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu); - if (is_gpu && exec_type == ExecType::kSync && rctx.get_stream() && !rctx.is_bulk) { - rctx.get_stream()->Wait(); - } DerefInputOutputRelease(inputs, outputs); }; if (exec_type == ExecType::kSubgraphExec || CheckIfSkipEngine(attrs)) { - RunContext rctx{ctx, nullptr, nullptr, false}; - run(rctx, engine::CallbackOnComplete()); + RunContext rctx{ctx, nullptr}; + run(rctx, engine::CallbackOnStart(), engine::CallbackOnComplete()); } else if (exec_type == ExecType::kSync) { - Engine::Get()->PushSync([=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); }, + Engine::Get()->PushSync([=](RunContext rctx) { run(rctx, + engine::CallbackOnStart(), + engine::CallbackOnComplete()); }, ctx, read_vars, write_vars, @@ -1248,7 +1244,9 @@ inline Engine::OprHandle CreateEngineOp( bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync; auto exec_fun = [execs, is_async, is_gpu](RunContext ctx, + Engine::CallbackOnStart on_start, Engine::CallbackOnComplete on_complete) { + on_start(); if (is_async) { execs[0]->op_ctx.async_on_complete = on_complete; } @@ -1257,10 +1255,7 @@ inline Engine::OprHandle CreateEngineOp( // call on complete only if it is async op if (!is_async) { if (is_gpu) { -#if MXNET_USE_CUDA - // Wait GPU kernel to finish. 
- ctx.get_stream<gpu>()->Wait();
ReduceSumCPUExSerial(reduce, &out) @@ -271,7 +276,10 @@ class CommCPU : public Comm { "consider create a new NDArray buffer to store the output."); } Engine::Get()->PushAsync( - [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { + [=](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); const TBlob& indices = row_id.data(); NDArray temp = retained_cpu; // get rid the of const qualifier op::SparseRetainOpForwardRspImpl( @@ -679,7 +687,10 @@ class CommDevice : public Comm { } bool is_gpu = retained_gpu.ctx().dev_mask() == gpu::kDevMask; Engine::Get()->PushAsync( - [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { + [=](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); const TBlob& indices = row_id.data(); using namespace mxnet::common; NDArray temp = retained_gpu; @@ -693,8 +704,6 @@ class CommDevice : public Comm { case gpu::kDevMask: { SparseRetainOpForwardRspWrapper( rctx.get_stream(), src, indices, kWriteTo, &temp); - // wait for GPU operations to complete - rctx.get_stream()->Wait(); break; } #endif diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc index bfec4fe16ee1..0205409853a3 100644 --- a/src/kvstore/gradient_compression.cc +++ b/src/kvstore/gradient_compression.cc @@ -151,8 +151,6 @@ void GradientCompression::Quantize(const mxnet::NDArray& from, [from, to, residual, threshold](mxnet::RunContext ctx) { std::vector inputs = {from.data(), residual->data(), to->data()}; Quantize1BitImpl(ctx.get_stream(), inputs, threshold); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); }, from.ctx(), {from.var()}, @@ -165,8 +163,6 @@ void GradientCompression::Quantize(const mxnet::NDArray& from, [from, to, residual, threshold](mxnet::RunContext ctx) { std::vector inputs = {from.data(), residual->data(), to->data()}; Quantize2BitImpl(ctx.get_stream(), inputs, threshold); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); }, from.ctx(), {from.var()}, diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index bc4ce424c416..09612a5aeb60 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -421,7 +421,9 @@ class KVStoreDist : public KVStoreLocal { } gradient_compression_->Quantize(comm_buf, &small_buf, &res_buf, priority); auto push_to_servers = [this, key, dtype, pskv, small_buf](RunContext rctx, + Engine::CallbackOnStart on_start, Engine::CallbackOnComplete cb) { + on_start(); size_t size = small_buf.shape().Size() * mshadow::mshadow_sizeof(dtype); char* data = static_cast(small_buf.data().dptr_); // do push. 
false means no delete @@ -442,7 +444,9 @@ class KVStoreDist : public KVStoreLocal { virtual void PushDefault(int key, const NDArray& send_buf, const PSKV& pskv, int priority) { auto push_to_servers = [this, key, pskv, send_buf](RunContext rctx, + Engine::CallbackOnStart on_start, Engine::CallbackOnComplete cb) { + on_start(); const int dtype = send_buf.dtype(); // convert to ps keys const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype); @@ -464,7 +468,10 @@ class KVStoreDist : public KVStoreLocal { // push row sparse gradient virtual void PushRowSparse(int key, const NDArray& send_buf, int priority) { using namespace rowsparse; - auto push_to_servers = [this, key, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) { + auto push_to_servers = [this, key, send_buf](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete cb) { + on_start(); char* data = static_cast(send_buf.data().dptr_); const int64_t num_rows = send_buf.aux_shape(kIdx)[0]; const auto offsets = send_buf.aux_data(kIdx).dptr(); @@ -492,7 +499,10 @@ class KVStoreDist : public KVStoreLocal { } virtual void PullDefault(int key, const NDArray& recv_buf, int priority) { - auto pull_from_servers = [this, key, recv_buf](RunContext rctx, Engine::CallbackOnComplete cb) { + auto pull_from_servers = [this, key, recv_buf](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete cb) { + on_start(); // convert to ps keys size_t size = recv_buf.shape().Size(); const int dtype = recv_buf.dtype(); @@ -531,7 +541,9 @@ class KVStoreDist : public KVStoreLocal { int priority) { using namespace rowsparse; auto pull_from_servers = [this, key, recv_buf, indices](RunContext rctx, + Engine::CallbackOnStart on_start, Engine::CallbackOnComplete cb) { + on_start(); // allocate memory for the buffer CHECK_EQ(indices.dtype(), mshadow::kInt64); const TBlob idx_data = indices.data(); @@ -573,7 +585,10 @@ class KVStoreDist : public KVStoreLocal { } virtual void PushPullDefault(int key, const NDArray& comm_buf, int priority) { - auto pushpull = [this, key, comm_buf](RunContext rctx, Engine::CallbackOnComplete cb) { + auto pushpull = [this, key, comm_buf](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete cb) { + on_start(); size_t size = comm_buf.shape().Size(); const int dtype = comm_buf.dtype(); const int num_bytes = mshadow::mshadow_sizeof(dtype); diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h index 29bc45521713..14276a9c38fe 100644 --- a/src/kvstore/kvstore_dist_server.h +++ b/src/kvstore/kvstore_dist_server.h @@ -430,7 +430,10 @@ class KVStoreDistServer { // accumulate row_sparse gradients using namespace mshadow; Engine::Get()->PushAsync( - [to_merge, updateBuf, out](RunContext ctx, Engine::CallbackOnComplete on_complete) { + [to_merge, updateBuf, out](RunContext ctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); op::ElemwiseBinaryOp::ComputeEx( {}, {}, {to_merge, updateBuf->merged}, {kWriteTo}, {out}); on_complete(); @@ -518,7 +521,10 @@ class KVStoreDistServer { store_[master_key] = NDArray(kRowSparseStorage, dshape, Context(), true, type.dtype); } Engine::Get()->PushAsync( - [this, recved, stored, type](RunContext ctx, Engine::CallbackOnComplete on_complete) { + [this, recved, stored, type](RunContext ctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); NDArray rsp = stored; 
stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])}); mshadow::Stream* s = ctx.get_stream(); diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 2a0ac3a90a7f..8f9dc9b9d2b0 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -493,7 +493,10 @@ class KVStoreLocal : public KVStore { // GPU requires temp resources bool is_gpu = out.ctx().dev_mask() == gpu::kDevMask; Engine::Get()->PushAsync( - [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { + [=](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); // copy data.data() to out.data() out.CheckAndAlloc({mshadow::Shape1(num_elements)}); TBlob out_data = out.data(); @@ -510,8 +513,6 @@ class KVStoreLocal : public KVStore { mshadow::Stream* s = rctx.get_stream(); ndarray::Copy(data_in_ctx.data(), &out_data, ctx, ctx, rctx); UniqueImpl(&workspace, s, out); - // wait for GPU operations to complete - s->Wait(); break; } #endif diff --git a/src/kvstore/p3store_dist.h b/src/kvstore/p3store_dist.h index 0e0aff045cb7..ed3875f96b45 100644 --- a/src/kvstore/p3store_dist.h +++ b/src/kvstore/p3store_dist.h @@ -77,9 +77,11 @@ class P3StoreDist : public KVStoreDist { LOG(FATAL) << "NotImplementedError: PushCompressed not implemented in P3StoreDist."; } - void PushDefault(int key, const NDArray& send_buf, const PSKV& pskv, int priority) override { - auto push_to_servers = [this, key, pskv, send_buf, priority](RunContext rctx, - Engine::CallbackOnComplete cb) { + void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) override { + auto push_to_servers = [this, key, pskv, send_buf, priority] (RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete cb) { + on_start(); const int dtype = send_buf.dtype(); // convert to ps keys const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype); @@ -87,7 +89,6 @@ class P3StoreDist : public KVStoreDist { // do push. 
false means no delete ps::SArray vals(data, size, false); int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype); - size_t off = 0; auto counter = new std::atomic(pskv.keys.size()); for (size_t idx = 0; idx < pskv.keys.size(); idx++) { @@ -127,7 +128,9 @@ class P3StoreDist : public KVStoreDist { CHECK(gradient_compression_->get_type() == CompressionType::kNone) << "Gradient compression not supported in P3StoreDist."; auto pull_from_servers = [this, key, recv_buf, priority](RunContext rctx, + Engine::CallbackOnStart on_start, Engine::CallbackOnComplete cb) { + on_start(); // convert to ps keys size_t size = recv_buf.shape().Size(); const int dtype = recv_buf.dtype(); @@ -181,7 +184,9 @@ class P3StoreDist : public KVStoreDist { CHECK(gradient_compression_->get_type() == CompressionType::kNone) << "Compression not supported in P3StoreDist"; auto pushpull = [this, key, comm_buf, priority](RunContext rctx, + Engine::CallbackOnStart on_start, Engine::CallbackOnComplete cb) { + on_start(); size_t size = comm_buf.shape().Size(); const int dtype = comm_buf.dtype(); const int num_bytes = mshadow::mshadow_sizeof(dtype); diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 3e64a8db88ef..91ed70ac9155 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -134,7 +134,22 @@ NDArray::Chunk::~Chunk() { #endif if (auto engine = engine_ref_.lock()) { engine->DeleteVariable( - [mem, skip_free](RunContext s) { + [mem, skip_free, var = this->var](RunContext s) mutable { +#if MXNET_USE_CUDA + auto &sync_obj = var->sync_object; + Storage::SyncObj storage_sync_obj; + { + std::lock_guard l(sync_obj.mutex); + for (auto& ev : sync_obj.reader_events) { + storage_sync_obj.events.push_back(ev.event); + } + if (!sync_obj.writer_event.empty()) { + auto ev = sync_obj.writer_event[0]; + storage_sync_obj.events.push_back(ev.event); + } + } + mem.h.sync_obj = storage_sync_obj; +#endif if (skip_free == false) { #if MXNET_USE_ONEDNN == 1 if (mem.mem) { @@ -746,16 +761,19 @@ void NDArray::Reorder2DefaultAsync() const { std::vector mutable_vars(1, this->var()); NDArray tmp = *this; Engine::Get()->PushAsync( - [tmp](RunContext ctx, Engine::CallbackOnComplete on_complete) { - tmp.ptr_->Reorder2Default(); - on_complete(); - }, - ctx(), - const_vars, - mutable_vars, - FnProperty::kNormal, - 0, - "Reorder2Default"); + [tmp](RunContext ctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); + tmp.ptr_->Reorder2Default(); + on_complete(); + }, + ctx(), + const_vars, + mutable_vars, + FnProperty::kNormal, + 0, + "Reorder2Default"); } // now just support bf16->fp32 @@ -778,20 +796,23 @@ void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::desc& desc) const { NDArray tmp = *this; const auto version = this->version(); Engine::Get()->PushAsync( - [tmp, version, desc](RunContext ctx, Engine::CallbackOnComplete on_complete) { - // MXNet will try to reuse NDArray from memory planning, so we need to ensure - // the NDArray is still holding the original trunk data. - if (tmp.version() == version) { - tmp.ptr_->MKLDNNDataReorder(desc); - } - on_complete(); - }, - ctx(), - const_vars, - mutable_vars, - FnProperty::kNormal, - 0, - "Reorder"); + [tmp, version, desc](RunContext ctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); + // MXNet will try to reuse NDArray from memory planning, so we need to ensure + // the NDArray is still holding the original trunk data. 
+ if (tmp.version() == version) { + tmp.ptr_->MKLDNNDataReorder(desc); + } + on_complete(); + }, + ctx(), + const_vars, + mutable_vars, + FnProperty::kNormal, + 0, + "Reorder"); } const mkldnn::memory* NDArray::GetMKLDNNData() const { @@ -993,8 +1014,6 @@ void TernaryOp(const NDArray& lhs, const NDArray& mhs, const NDArray& rhs, NDArr [lhs, mhs, rhs, ret](RunContext ctx) { TBlob tmp = ret.data(); ndarray::Eval(lhs.data(), mhs.data(), rhs.data(), &tmp, ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); }, lhs.ctx(), const_vars, @@ -1081,8 +1100,6 @@ void BinaryOpKernel(const NDArray& lhs, const NDArray& rhs, NDArray* out) { TBlob tmp = ret.data(); mshadow::Stream* s = ctx.get_stream(); ndarray::BinaryOpKernelImpl(s, lhs.data(), rhs.data(), &tmp); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); }, lhs.ctx(), const_vars, @@ -1132,8 +1149,6 @@ void BinaryOp(const NDArray& lhs, const NDArray& rhs, NDArray* out) { [lhs, rhs, ret](RunContext ctx) { TBlob tmp = ret.data(); ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); }, lhs.ctx(), const_vars, @@ -1177,7 +1192,6 @@ void SetValueOp(const real_t& rhs, NDArray* out) { ctx.get_stream()->Wait(); break; } -#endif default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } @@ -1234,8 +1248,6 @@ void ScalarOp(const NDArray& lhs, const real_t& rhs, NDArray* out) { [lhs, rhs, ret](RunContext ctx) { TBlob tmp = ret.data(); ndarray::Eval(lhs.data(), rhs, &tmp, ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); }, lhs.ctx(), const_vars, @@ -1461,7 +1473,10 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op if (a == cpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushAsync( - [from, to, requested](RunContext ctx, Engine::CallbackOnComplete on_complete) { + [from, to, requested](RunContext ctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); CopyFromToImpl(from, to, ctx, requested); on_complete(); }, @@ -1475,9 +1490,11 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op #if MXNET_USE_CUDA if (a == cpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushAsync( - [from, to, requested](RunContext ctx, Engine::CallbackOnComplete on_complete) { + [from, to, requested](RunContext ctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); CopyFromToImpl(from, to, ctx, requested); - ctx.get_stream()->Wait(); on_complete(); }, to.ctx(), @@ -1488,11 +1505,13 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op "CopyCPU2GPU"); } else if (a == gpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushAsync( - [from, to, requested](RunContext ctx, Engine::CallbackOnComplete on_complete) { + [from, to, requested](RunContext ctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); CopyFromToImpl(from, to, ctx, requested); - ctx.get_stream()->Wait(); on_complete(); - }, + }, from.ctx(), const_vars, mutable_vars, @@ -1501,9 +1520,11 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op "CopyGPU2CPU"); } else if (a == gpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushAsync( - [from, to, requested](RunContext ctx, Engine::CallbackOnComplete on_complete) { + [from, to, requested](RunContext ctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); CopyFromToImpl(from, to, 
ctx, requested); - ctx.get_stream()->Wait(); on_complete(); }, from.ctx(), @@ -1574,11 +1595,9 @@ void ElementwiseSum(const std::vector& source, NDArray* out, int priori } TBlob tmp = ret.data(); ndarray::ElementwiseSum(source_tblob, &tmp, ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); }, out->ctx(), - const_vars, + const_vars, {ret.var()}, FnProperty::kNormal, priority, @@ -1604,8 +1623,6 @@ void ElementwiseSum(const std::vector& source, NDArray* out, int priori #if MXNET_USE_CUDA case gpu::kDevMask: { mxnet::ndarray::ElementwiseSum(rctx.get_stream(), rsc, source, &result); - // wait for GPU operations to complete - rctx.get_stream()->Wait(); break; } #endif @@ -1699,8 +1716,6 @@ void SampleOP(const real_t& a, const real_t& b, NDArray* out) { [a, b, resource, ret](RunContext ctx) { TBlob tmp = ret.data(); ndarray::EvalRandom(a, b, resource, &tmp, ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); }, out->ctx(), {}, @@ -2179,17 +2194,19 @@ void NDArray::SyncCopyFromCPU(const void* data, size_t size) const { if (this->ctx().dev_mask() == cpu::kDevMask) { this->WaitToWrite(); - RunContext rctx{this->ctx(), nullptr, nullptr, false}; + RunContext rctx{this->ctx(), nullptr, nullptr}; TBlob dst = this->data(); ndarray::Copy(src, &dst, Context::CPU(), Context::CPU(), rctx); } else { #if MXNET_USE_CUDA Engine::Get()->PushAsync( - [&](RunContext rctx, Engine::CallbackOnComplete on_complete) { + [&](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); TBlob dst = this->data(); - ndarray::Copy(src, &dst, Context::CPU(), this->ctx(), rctx); - // Wait GPU kernel to complete - rctx.get_stream()->Wait(); + ndarray::Copy(src, &dst, + Context::CPU(), this->ctx(), rctx); on_complete(); }, this->ctx(), @@ -2265,11 +2282,13 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) { #if MXNET_USE_CUDA if (src_dev_mask == cpu::kDevMask && dst_dev_mask == gpu::kDevMask) { Engine::Get()->PushAsync( - [&](RunContext rctx, Engine::CallbackOnComplete on_complete) { + [&](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); const TBlob src_data = (i >= 0 ? src.aux_data(i) : src.data()); - TBlob dst_data = get_dst_data(src_data.shape_); + TBlob dst_data = get_dst_data(src_data.shape_); ndarray::Copy(src_data, &dst_data, src.ctx(), this->ctx(), rctx); - rctx.get_stream()->Wait(); on_complete(); }, this->ctx(), @@ -2280,11 +2299,13 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) { "SyncCopyFromNDArrayCPU2GPU"); } else if (src_dev_mask == gpu::kDevMask && dst_dev_mask == cpu::kDevMask) { Engine::Get()->PushAsync( - [&](RunContext rctx, Engine::CallbackOnComplete on_complete) { + [&](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); const TBlob src_data = (i >= 0 ? 
src.aux_data(i) : src.data()); - TBlob dst_data = get_dst_data(src_data.shape_); + TBlob dst_data = get_dst_data(src_data.shape_); ndarray::Copy(src_data, &dst_data, src.ctx(), this->ctx(), rctx); - rctx.get_stream()->Wait(); on_complete(); }, src.ctx(), @@ -2295,11 +2316,13 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) { "SyncCopyFromNDArrayGPU2CPU"); } else if (src_dev_mask == gpu::kDevMask && dst_dev_mask == gpu::kDevMask) { Engine::Get()->PushAsync( - [&](RunContext rctx, Engine::CallbackOnComplete on_complete) { + [&](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); const TBlob src_data = (i >= 0 ? src.aux_data(i) : src.data()); - TBlob dst_data = get_dst_data(src_data.shape_); + TBlob dst_data = get_dst_data(src_data.shape_); ndarray::Copy(src_data, &dst_data, src.ctx(), this->ctx(), rctx); - rctx.get_stream()->Wait(); on_complete(); }, this->ctx(), @@ -2343,7 +2366,7 @@ void NDArray::SyncCopyToCPU(void* data, size_t size) const { this->WaitToRead(); if (this->ctx().dev_mask() == cpu::kDevMask) { - RunContext rctx{this->ctx(), nullptr, nullptr, false}; + RunContext rctx{this->ctx(), nullptr, nullptr}; NDArray src = *this; #if MXNET_USE_ONEDNN == 1 if (src.IsMKLDNNData()) @@ -2353,10 +2376,41 @@ void NDArray::SyncCopyToCPU(void* data, size_t size) const { } else { #if MXNET_USE_CUDA Engine::Get()->PushAsync( - [&](RunContext rctx, Engine::CallbackOnComplete on_complete) { - ndarray::Copy(this->data(), &dst, this->ctx(), Context::CPU(), rctx); - // Wait GPU kernel to complete - rctx.get_stream()->Wait(); + [&](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); + { + auto var = this->var(); + auto& sync_obj = var->sync_object; + std::lock_guard lock{sync_obj.mutex}; + bool has_writer = false; + std::shared_ptr w_ev_ptr; + if (!sync_obj.writer_event.empty()) { + w_ev_ptr = sync_obj.writer_event[0].event.lock(); + has_writer = w_ev_ptr ? true : false; + } + for (auto ev : sync_obj.reader_events) { + auto event_ptr = ev.event.lock(); + if (!event_ptr) { + continue; + } + cudaEvent_t event = *event_ptr; + if (has_writer) { + auto w_ev = sync_obj.writer_event[0]; + if (w_ev.stream == ev.stream) { + event = w_ev.pool_index > ev.pool_index ? *w_ev_ptr : *event_ptr; + has_writer = false; + } + } + CUDA_CALL(cudaEventSynchronize(event)); + } + if (has_writer) { + CUDA_CALL(cudaEventSynchronize(*w_ev_ptr)); + } + } + ndarray::Copy(this->data(), &dst, + this->ctx(), Context::CPU(), rctx); on_complete(); }, this->ctx(), @@ -2389,7 +2443,6 @@ void NDArray::SyncCheckFormat(const bool full_check) const { Engine::Get()->PushSync( [&](RunContext rctx) { common::CheckFormatWrapper(rctx, *this, err_cpu, full_check); - rctx.get_stream()->Wait(); }, this->ctx(), {this->var()}, @@ -2428,7 +2481,10 @@ void NDArray::WaitToWrite() const { Imperative::DCInfo::Compute(*this); // Push an empty mutable function to flush all preceding reads to the variable. 
Engine::Get()->PushAsync( - [](RunContext, Engine::CallbackOnComplete on_complete) { on_complete(); }, + [](RunContext, Engine::CallbackOnStart on_start, Engine::CallbackOnComplete on_complete) { + on_start(); + on_complete(); + }, Context{}, {}, {ptr_->var}); diff --git a/src/operator/custom/ndarray_op.cc b/src/operator/custom/ndarray_op.cc index ac59d5f22b43..fe07a3e5bba6 100644 --- a/src/operator/custom/ndarray_op.cc +++ b/src/operator/custom/ndarray_op.cc @@ -87,7 +87,10 @@ void NDArrayOp::Forward(const OpContext& ctx, CHECK(param_.pinfo->forward(ptrs.size(), ptrs.data(), tags.data(), param_.pinfo->p_forward)); Engine::Get()->PushAsync( - [ndcpy, ctx](RunContext rctx, Engine::CallbackOnComplete on_complete) { + [ndcpy, ctx](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); ctx.async_on_complete(); on_complete(); }, @@ -144,7 +147,10 @@ void NDArrayOp::Backward(const OpContext& ctx, CHECK(param_.pinfo->backward(ptrs.size(), ptrs.data(), tags.data(), param_.pinfo->p_backward)); Engine::Get()->PushAsync( - [ndcpy, ctx](RunContext rctx, Engine::CallbackOnComplete on_complete) { + [ndcpy, ctx](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); ctx.async_on_complete(); on_complete(); }, diff --git a/src/operator/operator_util.cc b/src/operator/operator_util.cc index 07eba962b400..8235827a91ee 100644 --- a/src/operator/operator_util.cc +++ b/src/operator/operator_util.cc @@ -486,22 +486,16 @@ void SimpleOpRegEntryImpl::RegisterSourceImperative() { SourceFunction fun = fsource_[dev_mask]; OpReqType req = kWriteTo; - Engine::Get()->PushSync( - [ret, fun, dev_mask, req, env](RunContext ctx) { - TBlob tmp = ret.data(); - (*fun)(env, &tmp, req, ctx); -#if MXNET_USE_CUDA - if (dev_mask == gpu::kDevMask) { - ctx.get_stream()->Wait(); - } -#endif - }, - ret.ctx(), - {}, - write_vars, - FnProperty::kNormal, - 0, - "RegisterSourceImperative"); + Engine::Get()->PushSync([ret, fun, dev_mask, req, env](RunContext ctx) { + TBlob tmp = ret.data(); + (*fun)(env, &tmp, req, ctx); + }, + ret.ctx(), + {}, + write_vars, + FnProperty::kNormal, + 0, + "RegisterSourceImperative"); }; // register the function. NDArrayReg().set_body(body).set_num_use_vars(0).set_num_mutate_vars(1); @@ -668,22 +662,16 @@ void SimpleOpRegEntryImpl::RegisterUnaryImperative() { << "inplace operation is not enabled for operator " << name; } - Engine::Get()->PushSync( - [src, ret, fun, dev_mask, req, env](RunContext ctx) { - TBlob tmp = ret.data(); - (*fun)(src.data(), env, &tmp, req, ctx); -#if MXNET_USE_CUDA - if (dev_mask == gpu::kDevMask) { - ctx.get_stream()->Wait(); - } -#endif - }, - src.ctx(), - const_vars, - write_vars, - FnProperty::kNormal, - 0, - "RegisterUnaryImperative"); + Engine::Get()->PushSync([src, ret, fun, dev_mask, req, env](RunContext ctx) { + TBlob tmp = ret.data(); + (*fun)(src.data(), env, &tmp, req, ctx); + }, + src.ctx(), + const_vars, + write_vars, + FnProperty::kNormal, + 0, + "RegisterUnaryImperative"); }; // register the function. 
NDArrayReg().set_body(body).set_num_use_vars(1).set_num_mutate_vars(1); @@ -950,22 +938,16 @@ void SimpleOpRegEntryImpl::RegisterBinaryImperative() { << " warning, perform inplace operation with right operand, may not be supported"; } - Engine::Get()->PushSync( - [lhs, rhs, ret, fun, dev_mask, req, env](RunContext ctx) { - TBlob tmp = ret.data(); - (*fun)(lhs.data(), rhs.data(), env, &tmp, req, ctx); -#if MXNET_USE_CUDA - if (dev_mask == gpu::kDevMask) { - ctx.get_stream()->Wait(); - } -#endif - }, - lhs.ctx(), - const_vars, - write_vars, - FnProperty::kNormal, - 0, - "RegisterBinaryImperative"); + Engine::Get()->PushSync([lhs, rhs, ret, fun, dev_mask, req, env](RunContext ctx) { + TBlob tmp = ret.data(); + (*fun)(lhs.data(), rhs.data(), env, &tmp, req, ctx); + }, + lhs.ctx(), + const_vars, + write_vars, + FnProperty::kNormal, + 0, + "RegisterBinaryImperative"); }; // register the function. NDArrayReg().set_body(body).set_num_use_vars(2).set_num_mutate_vars(1); diff --git a/src/resource.cc b/src/resource.cc index 899f58d74df3..7c434c80abf0 100644 --- a/src/resource.cc +++ b/src/resource.cc @@ -266,17 +266,20 @@ class ResourceManagerImpl : public ResourceManager { inline void Seed(uint32_t seed) { mshadow::Random* r = prnd; Engine::Get()->PushAsync( - [r, seed](RunContext rctx, Engine::CallbackOnComplete on_complete) { - r->set_stream(rctx.get_stream()); - r->Seed(seed); - on_complete(); - }, - ctx, - {}, - {resource.var}, - FnProperty::kNormal, - 0, - "ResourceRandomSetSeed"); + [r, seed](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); + r->set_stream(rctx.get_stream()); + r->Seed(seed); + on_complete(); + }, + ctx, + {}, + {resource.var}, + FnProperty::kNormal, + 0, + "ResourceRandomSetSeed"); } }; @@ -341,39 +344,41 @@ class ResourceManagerImpl : public ResourceManager { uint32_t current_seed = p->ctx.dev_id + i * kMaxNumGPUs + seed * kRandMagic; Resource* r = &(p->resource[i]); Engine::Get()->PushAsync( - [r, current_seed](RunContext rctx, Engine::CallbackOnComplete on_complete) { - auto state_space = static_cast(r->ptr_); - mshadow::Stream* stream = rctx.get_stream(); - CHECK_EQ(state_space->ctx.dev_id, stream->dev_id) - << "The device id of cudnn dropout state space doesn't match that from stream."; - if (!state_space->handle.size) { - // not allocated yet - size_t dropout_state_size; - CUDNN_CALL(cudnnDropoutGetStatesSize(stream->dnn_handle_, &dropout_state_size)); - // reserve GPU space - Storage::Get()->DirectFree( - Storage::Get()->Alloc(dropout_state_size, state_space->ctx)); - state_space->GetSpace(dropout_state_size, "cudnn_dropout_state"); - } - cudnnDropoutDescriptor_t temp_descriptor; - CUDNN_CALL(cudnnCreateDropoutDescriptor(&temp_descriptor)); - CUDNN_CALL(cudnnSetDropoutDescriptor(temp_descriptor, - stream->dnn_handle_, - 0.5, - state_space->handle.dptr, - state_space->handle.size, - current_seed)); - CUDNN_CALL(cudnnDestroyDropoutDescriptor(temp_descriptor)); - cudaStream_t cuda_stream = mshadow::Stream::GetStream(stream); - cudaStreamSynchronize(cuda_stream); - on_complete(); - }, - p->ctx, - {}, - {r->var}, - FnProperty::kNormal, - 0, - "CUDNNDropoutDescriptorSeed"); + [r, current_seed](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); + auto state_space = static_cast(r->ptr_); + mshadow::Stream* stream = rctx.get_stream(); + CHECK_EQ(state_space->ctx.dev_id, stream->dev_id) + << "The device id of cudnn dropout state space doesn't match that from 
stream."; + if (!state_space->handle.size) { + // not allocated yet + size_t dropout_state_size; + CUDNN_CALL(cudnnDropoutGetStatesSize(stream->dnn_handle_, &dropout_state_size)); + // reserve GPU space + Storage::Get()->DirectFree( + Storage::Get()->Alloc(dropout_state_size, state_space->ctx)); + state_space->GetSpace(dropout_state_size, "cudnn_dropout_state"); + } + cudnnDropoutDescriptor_t temp_descriptor; + CUDNN_CALL(cudnnCreateDropoutDescriptor(&temp_descriptor)); + CUDNN_CALL(cudnnSetDropoutDescriptor(temp_descriptor, stream->dnn_handle_, + 0.5, + state_space->handle.dptr, + state_space->handle.size, + current_seed)); + CUDNN_CALL(cudnnDestroyDropoutDescriptor(temp_descriptor)); + cudaStream_t cuda_stream = mshadow::Stream::GetStream(stream); + cudaStreamSynchronize(cuda_stream); + on_complete(); + }, + p->ctx, + {}, + {r->var}, + FnProperty::kNormal, + 0, + "CUDNNDropoutDescriptorSeed"); } p->curr_ptr.store(0); @@ -448,16 +453,19 @@ class ResourceManagerImpl : public ResourceManager { inline void SeedOne(size_t i, uint32_t seed) { common::random::RandGenerator* r = sampler[i]; Engine::Get()->PushAsync( - [r, seed](RunContext rctx, Engine::CallbackOnComplete on_complete) { - r->Seed(rctx.get_stream(), seed); - on_complete(); - }, - ctx, - {}, - {resource[i].var}, - FnProperty::kNormal, - 0, - "ResourceNativeRandomSetSeed"); + [r, seed](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); + r->Seed(rctx.get_stream(), seed); + on_complete(); + }, + ctx, + {}, + {resource[i].var}, + FnProperty::kNormal, + 0, + "ResourceNativeRandomSetSeed"); } // get next resource in round roubin matter inline Resource GetNext() { diff --git a/src/storage/gpu_device_storage.h b/src/storage/gpu_device_storage.h index ee8be75aa2c9..a7d7af4d9950 100644 --- a/src/storage/gpu_device_storage.h +++ b/src/storage/gpu_device_storage.h @@ -61,6 +61,14 @@ inline void GPUDeviceStorage::Free(Storage::Handle handle) { #if MXNET_USE_NCCL std::lock_guard l(Storage::Get()->GetMutex(Context::kGPU)); #endif // MXNET_USE_NCCL +#if MXNET_USE_CUDA + for (auto ev : handle.sync_obj.events) { + auto valid_ev = ev.lock(); + if (valid_ev) { + MSHADOW_CUDA_CALL(cudaEventSynchronize(*valid_ev)); + } + } +#endif CUDA_CALL(cudaFree(handle.dptr)) profiler::GpuDeviceStorageProfiler::Get()->OnFree(handle); } diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h index a58fc6e18f1e..fff4549e837b 100644 --- a/src/storage/pooled_storage_manager.h +++ b/src/storage/pooled_storage_manager.h @@ -29,6 +29,7 @@ #include #include #include +#include #include "./storage_manager.h" #include "../profiler/storage_profiler.h" @@ -129,7 +130,9 @@ class PooledStorageManager : public StorageManager, public BucketingStrategy, pu void Free(Storage::Handle handle) override { // Insert returned memory in cache std::lock_guard lock(Storage::Get()->GetMutex(dev_type_)); - StoringMethod::InsertInCache(BucketingStrategy::get_bucket(handle.size), handle.dptr); + StoringMethod::InsertInCache(BucketingStrategy::get_bucket(handle.size), + handle.dptr, + handle.sync_obj); } void DirectFree(Storage::Handle handle) override { @@ -154,7 +157,7 @@ class PooledStorageManager : public StorageManager, public BucketingStrategy, pu UNSET_DEVICE(device_store); } - bool MemoryIsAvalable(size_t roundSize) const { + bool MemoryIsAvailable(size_t roundSize) const { const auto free = contextHelper_->freeMemorySize(); return free > roundSize && memory_allocation_limit_ <= free - roundSize; } 
@@ -178,7 +181,7 @@ void PooledStorageManager::Alloc(Storage::Hand if (!reuse_pool) { SET_DEVICE(device_store, contextHelper_, handle->ctx, true); roundSize = BucketingStrategy::RoundAllocSizeForBucket(bucket_id); - if (!MemoryIsAvalable(roundSize)) + if (!MemoryIsAvailable(roundSize)) ReleaseAllNoLock(false); void* ret = nullptr; @@ -204,7 +207,19 @@ void PooledStorageManager::Alloc(Storage::Hand handle->dptr = ret; } else { // Reusing memory - handle->dptr = reuse_pool->back(); + auto ptr_syncobj = reuse_pool->back(); + handle->dptr = ptr_syncobj.first; + if (dev_type_ == Context::kGPU) { + handle->sync_obj = ptr_syncobj.second; +#if MXNET_USE_CUDA + for (auto ev : handle->sync_obj.events) { + auto valid_ev = ev.lock(); + if (valid_ev) { + MSHADOW_CUDA_CALL(cudaEventSynchronize(*valid_ev)); + } + } +#endif + } reuse_pool->pop_back(); } #if MXNET_USE_CUDA @@ -378,11 +393,11 @@ class RoundPower2 : public RoundHelper { class UnorderedMapContainer { protected: inline void InitContainer(const RoundHelper* p) {} - inline void InsertInCache(size_t key, void* dptr) { - memory_pool_[key].push_back(dptr); + inline void InsertInCache(size_t key, void* dptr, Storage::SyncObj sync_obj) { + memory_pool_[key].emplace_back(dptr, sync_obj); } - inline std::vector* GetMemStorage(size_t key) { + inline std::vector>* GetMemStorage(size_t key) { auto&& reuse_it = memory_pool_.find(key); return reuse_it != memory_pool_.end() && reuse_it->second.size() ? &reuse_it->second : nullptr; } @@ -392,8 +407,8 @@ class UnorderedMapContainer { size_t released_memory = 0; for (auto&& i : memory_pool_) { for (auto&& j : i.second) { - contextHelper->Free(j); - GPU_PROFILER_ON_FREE(profilerGPU, j); + contextHelper->Free(j.first); + GPU_PROFILER_ON_FREE(profilerGPU, j.first); } released_memory += i.first * i.second.size(); i.second.clear(); @@ -403,7 +418,7 @@ class UnorderedMapContainer { } private: - std::unordered_map> memory_pool_; + std::unordered_map>> memory_pool_; }; // class UnorderedMapContainer /*! @@ -422,11 +437,11 @@ class VectorContainer { memory_pool_.resize(vector_size); } - inline void InsertInCache(size_t idx, void* dptr) { - memory_pool_[idx].push_back(dptr); + inline void InsertInCache(size_t idx, void* dptr, Storage::SyncObj sync_obj) { + memory_pool_[idx].emplace_back(dptr, sync_obj); } - std::vector* GetMemStorage(size_t idx) { + std::vector>* GetMemStorage(size_t idx) { auto&& reuse_pool = memory_pool_[idx]; return reuse_pool.size() ? 
&reuse_pool : nullptr; } @@ -439,8 +454,8 @@ class VectorContainer { continue; for (auto& j : memory_pool_[i]) { - contextHelper->Free(j); - GPU_PROFILER_ON_FREE(profilerGPU, j); + contextHelper->Free(j.first); + GPU_PROFILER_ON_FREE(profilerGPU, j.first); } released_memory += rndHelper->get_size(i) * memory_pool_[i].size(); memory_pool_[i].clear(); @@ -449,7 +464,7 @@ class VectorContainer { } private: - std::vector> memory_pool_; + std::vector>> memory_pool_; size_t first_bucket_; }; // class VectorContainer diff --git a/src/storage/storage.cc b/src/storage/storage.cc index 04760b346da6..90aa302b1476 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -256,7 +256,7 @@ const std::string env_var_name(const char* dev_type, env_var_type type) { } // namespace storage -std::shared_ptr Storage::_GetSharedRef() { +const std::shared_ptr &Storage::_GetSharedRef() { #ifdef __MXNET_JS__ // dummy code needed for emscripten code to pass // do not know why, the new will be NULLPTR diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc index 11ca2c94c1c0..49a5abaed6ed 100644 --- a/tests/cpp/engine/threaded_engine_test.cc +++ b/tests/cpp/engine/threaded_engine_test.cc @@ -110,8 +110,10 @@ double EvaluateWorkloads(const std::vector& workloads, if (engine == nullptr) { EvaluateWorkload(wl, data); } else { - auto func = [wl, data](RunContext ctx, Engine::CallbackOnComplete cb) { - EvaluateWorkload(wl, data); cb(); + auto func = [wl, data](RunContext ctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete cb) { + on_start(); EvaluateWorkload(wl, data); cb(); }; std::vector reads; for (auto i : wl.reads) { @@ -182,7 +184,7 @@ TEST(Engine, RandSumExpr) { void Foo(mxnet::RunContext, int i) { printf("The fox says %d\n", i); } -void FooAsyncFunc(void*, void* cb_ptr, void* param) { +void FooAsyncFunc(void*, void*, void* cb_ptr, void* param) { if (param == nullptr) { LOG(INFO) << "The fox asynchronously says receiving nothing."; } else { @@ -346,7 +348,10 @@ TEST(Engine, basics) { printf("============= Test #1 ==============\n"); for (int i = 0; i < 10; ++i) { oprs.push_back(engine->NewOperator( - [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { + [i](mxnet::RunContext ctx, + mxnet::Engine::CallbackOnStart on_start, + mxnet::Engine::CallbackOnComplete cb) { + on_start(); Foo(ctx, i); std::this_thread::sleep_for(std::chrono::seconds{1}); cb(); @@ -368,7 +373,10 @@ TEST(Engine, basics) { oprs.clear(); for (int i = 0; i < 10; ++i) { oprs.push_back(engine->NewOperator( - [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { + [i](mxnet::RunContext ctx, + mxnet::Engine::CallbackOnStart on_start, + mxnet::Engine::CallbackOnComplete cb) { + on_start(); Foo(ctx, i); std::this_thread::sleep_for(std::chrono::milliseconds{500}); cb(); @@ -394,8 +402,11 @@ TEST(Engine, basics) { var = engine->NewVariable(); oprs.clear(); oprs.push_back(engine->NewOperator( - [](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { + [](mxnet::RunContext ctx, + mxnet::Engine::CallbackOnStart on_start, + mxnet::Engine::CallbackOnComplete cb) { std::this_thread::sleep_for(std::chrono::seconds{2}); + on_start(); Foo(ctx, 42); cb(); }, @@ -414,7 +425,10 @@ TEST(Engine, basics) { var = engine->NewVariable(); oprs.clear(); oprs.push_back(engine->NewOperator( - [](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { + [](mxnet::RunContext ctx, + mxnet::Engine::CallbackOnStart on_start, + mxnet::Engine::CallbackOnComplete cb) { + 
on_start(); Foo(ctx, 42); std::this_thread::sleep_for(std::chrono::seconds{2}); cb(); @@ -452,7 +466,10 @@ TEST(Engine, VarVersion) { EXPECT_EQ(var->version(), 0U); for (int i = 0; i < 10; ++i) { oprs.push_back(engine->NewOperator( - [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { + [i](mxnet::RunContext ctx, + mxnet::Engine::CallbackOnStart on_start, + mxnet::Engine::CallbackOnComplete cb) { + on_start(); Foo(ctx, i); cb(); }, @@ -473,7 +490,10 @@ TEST(Engine, VarVersion) { oprs.clear(); for (int i = 0; i < 10; ++i) { oprs.push_back(engine->NewOperator( - [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { + [i](mxnet::RunContext ctx, + mxnet::Engine::CallbackOnStart on_start, + mxnet::Engine::CallbackOnComplete cb) { + on_start(); Foo(ctx, i); cb(); }, diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 134eab397640..77706bf8a220 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -496,92 +496,6 @@ def tensor_size(big_tensor_bytes): # Evaluate model net(data_in).asnumpy() -# isolated execution bulking test function to be invoked with different env var settings - - -@mx.util.use_np -def _test_bulking_in_process(seed, time_per_iteration): - # Use flip since it's a simple function with same-sized I/O unlikely to ever be fused. - class Flip(gluon.HybridBlock): - def __init__(self, **kwargs): - super(Flip, self).__init__(**kwargs) - - def forward(self, x): - return mx.np.flip(x, axis=0) - - def get_net(num_ops): - net = nn.HybridSequential() - for _ in range(num_ops): - net.add(Flip()) - return net - - data_shape = (10,) - num_ops = 1000 - num_iterations = 20 - - # build model - x = mx.np.zeros(data_shape) - x.attach_grad() - dy = mx.np.ones(data_shape) - net = get_net(num_ops) - net.hybridize(static_alloc=True, static_shape=True) - - # time a number of forward() and backward() executions after some warm-up iterations - warmups = 1 - for i in range(num_iterations + warmups): - with autograd.record(): - if i == warmups: - start = time.time() - y = net(x) - y.backward(dy) - x.grad.wait_to_read() - - time_per_iteration.value = (time.time() - start) / num_iterations - -def _test_bulking(test_bulking_func): - # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training) - test_cases = [(0, 0, True), (1, 1, True), (15, 15, False), - (15, 0, True), (0, 15, True), (15, 15, True)] - times = {} - times_str = '' - for seg_sizes in test_cases: - # Create shared variable to return measured time from test process - time_per_iteration = mp.Manager().Value('d', 0.0) - - if not run_in_spawned_process(test_bulking_func, - {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': str(seg_sizes[0]), - 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': str(seg_sizes[1]), - 'MXNET_EXEC_BULK_EXEC_TRAIN': str(seg_sizes[2])}, - time_per_iteration): - # skip test since the python version can't run it properly. Warning msg was logged. 
- return - times[seg_sizes] = time_per_iteration.value - times_str += \ - '\n runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format( - seg_sizes[0], seg_sizes[1], seg_sizes[2], 1000.0 * times[seg_sizes]) - - fastest_non_bulked_time = min(times[(0, 0, True)], times[(1, 1, True)], times[(15, 15, False)]) - slowest_half_bulked_time = max(times[(0, 15, True)], times[(15, 0, True)]) - fastest_half_bulked_time = min(times[(0, 15, True)], times[(15, 0, True)]) - fully_bulked_time = times[(15, 15, True)] - - print(times_str) - # Non-bulked times[0,0,True], times[1,1,True] and times[15,15,False] should be about the same, - # slower than both half-bulked times[0,15,True] and times[15,0,True] - assert slowest_half_bulked_time < fastest_non_bulked_time, \ - 'A half-bulked exec time is slower than the non-bulked time by {} secs! {}' \ - .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str) - # The fully bulked times[15,15,True] should be faster than both half-bulked runs - assert fully_bulked_time < fastest_half_bulked_time, \ - 'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}' \ - .format(fully_bulked_time - fastest_half_bulked_time, times_str) - -@pytest.mark.skip(reason='skippping temporarily, tracked by https://github.com/apache/incubator-mxnet/issues/14970') -def test_bulking_gluon_gpu(): - _test_bulking(_test_bulking_in_process) - - -@mx.util.use_np def test_hybridblock_mix_ctx_raise(): class FooHybrid(gluon.HybridBlock): def forward(self, a, b): diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 9ce005b2f72f..195e40906b71 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -48,7 +48,6 @@ from test_sparse_operator import * from test_ndarray import * from test_subgraph_op import * -from test_gluon_gpu import _test_bulking from test_contrib_operator import test_multibox_target_op from test_optimizer import test_adamW del test_custom_op_fork #noqa @@ -2115,78 +2114,6 @@ def test_bilinear_sampler_versions(): if req_dict['grid'] is 'write': assert_almost_equal(exe.grad_dict['grid'], exe_list[ref_idx].grad_dict['grid'], rtol=1e-3, atol=1e-5) - -# isolated execution bulking test function to be invoked with different env var settings -def _test_bulking_in_process(seed, time_per_iteration): - data_shape = (10,) - num_ops = 1000 - num_iterations = 20 - - ctx = default_context() - # build symbol - X = mx.sym.Variable('X') - sym = mx.sym.flip(X, axis=0) - for _ in range(num_ops-1): - sym = mx.sym.flip(sym, axis=0) - x = mx.ndarray.zeros(data_shape) - dx = mx.ndarray.zeros(data_shape) - dy = mx.ndarray.ones(data_shape) - exe = sym._bind(ctx=ctx, args=[x], args_grad = {'X':dx}) - - # time a number of forward() and backward() executions after some warm-up iterations - warmups = 1 - for i in range(num_iterations+warmups): - if i == warmups: - start = time.time() - exe.forward(is_train=True) - exe.backward(dy) - dx.wait_to_read() - time_per_iteration.value = (time.time() - start) / num_iterations - - -@pytest.mark.skip(reason='skippping temporarily, tracked by https://github.com/apache/incubator-mxnet/issues/16517') -def test_bulking_operator_gpu(): - _test_bulking(_test_bulking_in_process) - - -@pytest.mark.skip(reason='skippping temporarily, tracked by https://github.com/apache/incubator-mxnet/issues/14970') -def test_bulking(): - # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training) - test_cases = [(0,0,True), 
(1,1,True), (15,15,False), (15,0,True), (0,15,True), (15,15,True)] - times = {} - times_str = '' - for seg_sizes in test_cases: - # Create shared variable to return measured time from test process - time_per_iteration = mp.Manager().Value('d', 0.0) - if not run_in_spawned_process(_test_bulking_in_process, - {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : str(seg_sizes[0]), - 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : str(seg_sizes[1]), - 'MXNET_EXEC_BULK_EXEC_TRAIN' : str(seg_sizes[2])}, - time_per_iteration): - # skip test since the python version can't run it properly. Warning msg was logged. - return - times[seg_sizes] = time_per_iteration.value - times_str += \ - '\n runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format( - seg_sizes[0], seg_sizes[1], seg_sizes[2], 1000.0 * times[seg_sizes]) - - fastest_non_bulked_time = min(times[(0,0,True)], times[(1,1,True)], times[(15,15,False)]) - slowest_half_bulked_time = max(times[(0,15,True)], times[(15,0,True)]) - fastest_half_bulked_time = min(times[(0,15,True)], times[(15,0,True)]) - fully_bulked_time = times[(15,15,True)] - - print(times_str) - # Non-bulked times[0,0,True], times[1,1,True] and times[15,15,False] should be about the same, - # slower than both half-bulked times[0,15,True] and times[15,0,True] - assert slowest_half_bulked_time < fastest_non_bulked_time, \ - 'A half-bulked exec time is slower than the non-bulked time by {} secs! {}' \ - .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str) - # The fully bulked times[15,15,True] should be faster than both half-bulked runs - assert fully_bulked_time < fastest_half_bulked_time, \ - 'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}' \ - .format(fully_bulked_time - fastest_half_bulked_time, times_str) - - @pytest.mark.serial def test_allclose_function_gpu(): allclose_function([mx.cpu(), mx.gpu(0)]) From 6aae92bd43b164276c41403a926667274842baba Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Tue, 8 Jun 2021 14:55:48 +0900 Subject: [PATCH 2/8] Temporarely skip byteps test Signed-off-by: Serge Panev --- ci/jenkins/Jenkinsfile_unix_gpu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/jenkins/Jenkinsfile_unix_gpu b/ci/jenkins/Jenkinsfile_unix_gpu index 2beb0f4aa1f4..53224e947bc5 100644 --- a/ci/jenkins/Jenkinsfile_unix_gpu +++ b/ci/jenkins/Jenkinsfile_unix_gpu @@ -49,7 +49,8 @@ core_logic: { custom_steps.test_unix_cpp_package_gpu('gpu'), // TODO(szha): fix and reenable the hanging issue. tracked in #18098 // custom_steps.test_unix_distributed_kvstore_gpu('gpu'), - custom_steps.test_unix_byteps_gpu('gpu'), + // TODO(spanev): reenable when byteps is updated with the new dep engine API + // custom_steps.test_unix_byteps_gpu('gpu'), ]) } , From 97b51f3416febcadf2e542aed30910221e6dad70 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Tue, 29 Jun 2021 14:21:01 +0900 Subject: [PATCH 3/8] Fix typo Signed-off-by: Serge Panev --- include/mxnet/base.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/mxnet/base.h b/include/mxnet/base.h index aecd7a300394..61a25cb4738b 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -348,7 +348,7 @@ struct RunContext { */ void *aux_stream; /*! - * \brief pointer to the cuda event pool used by the dependecy engine + * \brief pointer to the cuda event pool used by the dependency engine */ void *event_pool = nullptr; /*! 
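Usage note for the series: with PATCH 6/8 below, the async GPU dependency engine is no longer toggled by the MXNET_ASYNC_GPU_ENGINE variable; CreateEngine() and IsEngineAsync() instead look for a trailing "Async" tag on MXNET_ENGINE_TYPE. A minimal opt-in sketch from Python, assuming the default ThreadedEnginePerDevice engine and that the variable is set before the engine singleton is first created:

    import os
    # Select the per-device threaded engine with the async GPU dependency engine enabled.
    os.environ["MXNET_ENGINE_TYPE"] = "ThreadedEnginePerDeviceAsync"
    import mxnet as mx  # the engine type is read when the engine singleton is created

The part of the string before the "Async" tag still selects the underlying engine implementation, so the existing engine names keep working unchanged without the tag.
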
From 858968afe0954927ffb8d7c99f61825675c75999 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Tue, 7 Sep 2021 15:53:10 +0200 Subject: [PATCH 4/8] Fix lint Signed-off-by: Serge Panev --- src/ndarray/ndarray.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 91ed70ac9155..e4387b3d218c 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -1511,7 +1511,7 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op on_start(); CopyFromToImpl(from, to, ctx, requested); on_complete(); - }, + }, from.ctx(), const_vars, mutable_vars, @@ -1597,7 +1597,7 @@ void ElementwiseSum(const std::vector& source, NDArray* out, int priori ndarray::ElementwiseSum(source_tblob, &tmp, ctx); }, out->ctx(), - const_vars, + const_vars, {ret.var()}, FnProperty::kNormal, priority, From c9ff58bd9cb1550299d294983e04161fffb82278 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Tue, 7 Sep 2021 15:57:50 +0200 Subject: [PATCH 5/8] Fix bad cast Signed-off-by: Serge Panev --- include/mxnet/engine.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index e37c0e646e29..ed77bbf97d81 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -77,7 +77,7 @@ class CUDAEventPool final { } inline std::pair, uint64_t> GetNextEvent() noexcept { - int c = counter_++; + uint64_t c = counter_++; return {events_.at((c) % kPoolSize).GetEvent(), c}; } From ad587d7b2c8853e36dcefe99ee8acaf35f2930e2 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Wed, 8 Sep 2021 12:37:27 +0200 Subject: [PATCH 6/8] Move Async engine tag to MXNET_ENGINE_TYPE Signed-off-by: Serge Panev --- src/engine/engine.cc | 8 ++++++++ src/engine/threaded_engine.cc | 13 ++++++++++--- src/ndarray/ndarray.cc | 3 +-- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/engine/engine.cc b/src/engine/engine.cc index c491b5cf0b64..429e547ca915 100644 --- a/src/engine/engine.cc +++ b/src/engine/engine.cc @@ -36,6 +36,14 @@ inline Engine* CreateEngine() { type = "ThreadedEnginePerDevice"; std::string stype = type; + // The async tag is used later to determine if we use the GPU dependecy engine + std::string async_engine_tag = "Async"; + auto tag_pos = stype.find(async_engine_tag); + if (tag_pos != std::string::npos + && tag_pos + async_engine_tag.length() == stype.length()) { + stype = stype.substr(0, tag_pos); + } + Engine* ret = nullptr; #if MXNET_PREDICT_ONLY == 0 if (stype == "NaiveEngine") { diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index deba125ed8e2..2308777ab2fa 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -572,9 +572,16 @@ static inline void AddEventHelper( } } +static inline bool IsEngineAsync() { + std::string type = dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string("")); + std::string async_engine_tag("Async"); + auto tag_pos = type.find(async_engine_tag); + return tag_pos != std::string::npos; +} + void ThreadedEngine::OnStartCPU(Engine *engine, void *opr_block, const dmlc::Error* error) { - static bool use_new_dep_engine = dmlc::GetEnv("MXNET_ASYNC_GPU_ENGINE", false); + static bool use_new_dep_engine = IsEngineAsync(); if (!use_new_dep_engine) { return; } @@ -629,7 +636,7 @@ void ThreadedEngine::OnStartCPU(Engine *engine, void *opr_block, void ThreadedEngine::OnStartGPU(Engine *engine, void *sync_info, const dmlc::Error* error) { - static bool use_new_dep_engine = dmlc::GetEnv("MXNET_ASYNC_GPU_ENGINE", 
false); + static bool use_new_dep_engine = IsEngineAsync(); if (!use_new_dep_engine) { return; } @@ -707,7 +714,7 @@ void ThreadedEngine::OnCompleteGPU(Engine *engine, void *sync_info, CHECK(info->stream != nullptr); auto *worker_stream = reinterpret_cast *>(info->stream); - static bool use_new_dep_engine = dmlc::GetEnv("MXNET_ASYNC_GPU_ENGINE", false); + static bool use_new_dep_engine = IsEngineAsync(); if (!use_new_dep_engine) { worker_stream->Wait(); diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index e4387b3d218c..9d0703e9b66c 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -1188,10 +1188,9 @@ void SetValueOp(const real_t& rhs, NDArray* out) { } else { ndarray::Eval(ctx.get_stream(), rhs, ret); } - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); break; } +#endif default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } From 19cc3a5168cf1cd10132b63a0d6811559e6937a3 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Wed, 13 Oct 2021 15:15:39 +0900 Subject: [PATCH 7/8] clang-format Signed-off-by: Serge Panev --- include/mxnet/base.h | 4 +- include/mxnet/engine.h | 36 ++++--- include/mxnet/storage.h | 8 +- src/c_api/c_api.cc | 12 +-- src/common/object_pool.h | 4 +- src/engine/engine.cc | 11 +- src/engine/naive_engine.cc | 19 ++-- src/engine/stream_manager.h | 2 +- src/engine/threaded_engine.cc | 114 +++++++++------------ src/engine/threaded_engine.h | 33 +++--- src/engine/threaded_engine_perdevice.cc | 25 +++-- src/engine/threaded_engine_pooled.cc | 14 +-- src/imperative/imperative_utils.h | 38 +++---- src/kvstore/p3store_dist.h | 8 +- src/ndarray/ndarray.cc | 78 +++++++------- src/operator/operator_util.cc | 63 ++++++------ src/resource.cc | 125 ++++++++++++----------- src/storage/pooled_storage_manager.h | 7 +- src/storage/storage.cc | 2 +- tests/cpp/engine/threaded_engine_test.cc | 31 ++++-- 20 files changed, 311 insertions(+), 323 deletions(-) diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 61a25cb4738b..b403cd7278f0 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -347,10 +347,10 @@ struct RunContext { * \brief the auxiliary stream of the device, can be nullptr or Stream* in GPU mode */ void *aux_stream; - /*! + /*! * \brief pointer to the cuda event pool used by the dependency engine */ - void *event_pool = nullptr; + void* event_pool = nullptr; /*! * \brief get mshadow stream from Context * \return the mshadow stream diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index ed77bbf97d81..cdb8998d2e83 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -46,8 +46,7 @@ class CUDAEvent final { public: explicit CUDAEvent(Context const& ctx); -CUDAEvent(CUDAEvent&& other) - : event_(other.event_), dev_id_(other.dev_id_) { + CUDAEvent(CUDAEvent&& other) : event_(other.event_), dev_id_(other.dev_id_) { other.event_ = nullptr; } @@ -59,6 +58,7 @@ CUDAEvent(CUDAEvent&& other) inline std::weak_ptr GetEvent() noexcept { return event_; } + private: std::shared_ptr event_; int dev_id_; @@ -84,6 +84,7 @@ class CUDAEventPool final { inline uint64_t GetCounterValue() noexcept { return counter_.load(); } + private: static constexpr size_t kPoolSize = 64; std::vector events_; @@ -155,7 +156,7 @@ class CallbackOnStart { /*! \brief engine can see content of callback */ friend class ::mxnet::Engine; /*! \brief the real callback */ - void (*callback_)(Engine *, void *, const dmlc::Error *); + void (*callback_)(Engine*, void*, const dmlc::Error*); /*! 
\brief the engine class passed to callback */ Engine* engine_; /*! \brief the parameter set on callback */ @@ -345,7 +346,7 @@ class MXNET_API Engine { * * \return A shared pointer to Engine singleton. */ - static const std::shared_ptr &_GetSharedRef(); + static const std::shared_ptr& _GetSharedRef(); /*! * \brief Push an synchronous operation to the engine. * \param exec_fn Execution function that executes the operation. @@ -364,13 +365,18 @@ class MXNET_API Engine { FnProperty prop = FnProperty::kNormal, int priority = 0, const char* opr_name = nullptr) { - this->PushAsync([exec_fn](RunContext ctx, - CallbackOnStart on_start, - CallbackOnComplete on_complete) { - on_start(); - exec_fn(ctx); - on_complete(); - }, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name); + this->PushAsync( + [exec_fn](RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) { + on_start(); + exec_fn(ctx); + on_complete(); + }, + exec_ctx, + const_vars, + mutable_vars, + prop, + priority, + opr_name); } /*! @@ -378,12 +384,12 @@ class MXNET_API Engine { * \param callback th static callback function. * \param param the paramter passed to callback. */ - inline CallbackOnStart CreateOnStart( - void (*callback)(Engine *, void *, const dmlc::Error *), void *param) { + inline CallbackOnStart CreateOnStart(void (*callback)(Engine*, void*, const dmlc::Error*), + void* param) { CallbackOnStart ret; ret.callback_ = callback; - ret.engine_ = this; - ret.param_ = param; + ret.engine_ = this; + ret.param_ = param; return ret; } diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 46e4e3326790..06db6cecc15b 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -77,9 +77,9 @@ class Storage { std::string profiler_scope{MXNET_STORAGE_DEFAULT_PROFILER_SCOPE_CSTR}; std::string name{MXNET_STORAGE_DEFAULT_NAME_CSTR}; /*! - * \brief Used to pass events back and forth between the engine Var - * and the storage manager. - */ + * \brief Used to pass events back and forth between the engine Var + * and the storage manager. + */ SyncObj sync_obj; }; /*! @@ -154,7 +154,7 @@ class Storage { * * \return A shared pointer to Storage singleton. */ - static const std::shared_ptr &_GetSharedRef(); + static const std::shared_ptr& _GetSharedRef(); private: std::mutex cpu_mutex_; diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 4eb9713a8107..45ae9580a61e 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -3796,19 +3796,17 @@ int MXEnginePushAsync(EngineAsyncFunc async_func, Engine::AsyncFn exec_fn; if (deleter == nullptr) { - exec_fn = [async_func, func_param](RunContext rctx, - CallbackOnStart on_start, - CallbackOnComplete on_complete) { + exec_fn = [async_func, func_param]( + RunContext rctx, CallbackOnStart on_start, CallbackOnComplete on_complete) { async_func(&rctx, &on_start, &on_complete, func_param); }; } else { // Wrap func_param in a shared_ptr with deleter such that deleter // will be called when the lambda goes out of scope. 
std::shared_ptr shared_func_param(func_param, deleter); - exec_fn = [async_func, shared_func_param](RunContext rctx, - CallbackOnStart on_start, - CallbackOnComplete on_complete) { - async_func(&rctx, &on_start,, &on_complete, shared_func_param.get()); + exec_fn = [async_func, shared_func_param]( + RunContext rctx, CallbackOnStart on_start, CallbackOnComplete on_complete) { + async_func(&rctx, &on_start, , &on_complete, shared_func_param.get()); }; } diff --git a/src/common/object_pool.h b/src/common/object_pool.h index 023e9d6f2df2..66385b9ade64 100644 --- a/src/common/object_pool.h +++ b/src/common/object_pool.h @@ -61,7 +61,7 @@ class ObjectPool { * \brief Get a shared ptr of the singleton instance of pool. * \return Shared pointer to the Object Pool. */ - static const std::shared_ptr &_GetSharedRef(); + static const std::shared_ptr& _GetSharedRef(); private: /*! @@ -170,7 +170,7 @@ ObjectPool* ObjectPool::Get() { } template -const std::shared_ptr > &ObjectPool::_GetSharedRef() { +const std::shared_ptr >& ObjectPool::_GetSharedRef() { static std::shared_ptr > inst_ptr(new ObjectPool()); return inst_ptr; } diff --git a/src/engine/engine.cc b/src/engine/engine.cc index 429e547ca915..2e1e0500ef82 100644 --- a/src/engine/engine.cc +++ b/src/engine/engine.cc @@ -38,9 +38,8 @@ inline Engine* CreateEngine() { // The async tag is used later to determine if we use the GPU dependecy engine std::string async_engine_tag = "Async"; - auto tag_pos = stype.find(async_engine_tag); - if (tag_pos != std::string::npos - && tag_pos + async_engine_tag.length() == stype.length()) { + auto tag_pos = stype.find(async_engine_tag); + if (tag_pos != std::string::npos && tag_pos + async_engine_tag.length() == stype.length()) { stype = stype.substr(0, tag_pos); } @@ -67,8 +66,8 @@ inline Engine* CreateEngine() { } #if MXNET_USE_CUDA -CUDAEvent::CUDAEvent(Context const& ctx) : - event_(std::make_shared()), dev_id_(ctx.dev_id) { +CUDAEvent::CUDAEvent(Context const& ctx) + : event_(std::make_shared()), dev_id_(ctx.dev_id) { cudaEvent_t ev; common::cuda::DeviceStore device_store(dev_id_); CUDA_CALL(cudaEventCreateWithFlags(&ev, cudaEventDisableTiming)); @@ -85,7 +84,7 @@ CUDAEvent::~CUDAEvent() { #endif } // namespace engine -const std::shared_ptr &Engine::_GetSharedRef() { +const std::shared_ptr& Engine::_GetSharedRef() { static std::shared_ptr sptr(engine::CreateEngine()); return sptr; } diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index 76dd04249649..ad24af1dabe9 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -118,9 +118,7 @@ class NaiveEngine final : public Engine { NaiveOpr* opr = op->Cast(); opr->profiling = profiling && profiler->IsProfiling(profiler::Profiler::kSymbolic); this->PushAsync( - [&](RunContext ctx, - CallbackOnStart on_start, - CallbackOnComplete on_complete) { + [&](RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) { if (opr->profiling) { std::unique_ptr attrs; if (profiler->AggregateEnabled()) { @@ -158,7 +156,7 @@ class NaiveEngine final : public Engine { bool wait = false) override { std::promise promise; std::future future = promise.get_future(); - CallbackOnStart on_start = CreateOnStart(NaiveEngine::OnStart, &promise); + CallbackOnStart on_start = CreateOnStart(NaiveEngine::OnStart, &promise); CallbackOnComplete callback = CreateCallback(NaiveEngine::OnComplete, &promise); profiler::Profiler* profiler = profiler::Profiler::Get(); auto opr_deleter = [this](NaiveOpr* p) { this->DeleteOperator(p); }; @@ -192,9 +190,7 @@ 
class NaiveEngine final : public Engine { streams_[dev_id] = mshadow::NewStream(true, MXNET_USE_CUDNN != 0, dev_id); aux_streams_[dev_id] = new GPUAuxStream(streams_[dev_id]); } - exec_fun(RunContext{exec_ctx, streams_[dev_id], aux_streams_[dev_id]}, - on_start, - callback); + exec_fun(RunContext{exec_ctx, streams_[dev_id], aux_streams_[dev_id]}, on_start, callback); #else LOG(FATAL) << "GPU is not enabled"; #endif @@ -214,9 +210,8 @@ class NaiveEngine final : public Engine { void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override { NaiveVar* naive_var = NaiveVar::CastFromBase(var); this->PushAsync( - [delete_fn, naive_var](RunContext ctx, - CallbackOnStart on_start, - CallbackOnComplete on_complete) mutable { + [delete_fn, naive_var]( + RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) mutable { on_start(); delete_fn(ctx); NaiveVar::Delete(naive_var); @@ -242,9 +237,7 @@ class NaiveEngine final : public Engine { private: // onstart - static void OnStart(Engine *engine, void *param, - const dmlc::Error* error) { - } + static void OnStart(Engine* engine, void* param, const dmlc::Error* error) {} // callback to oncomplete static void OnComplete(Engine* engine, void* param, const dmlc::Error* error) { static_cast*>(param)->set_value(); diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h index f0cfb19fa7de..2384e1f19ec2 100644 --- a/src/engine/stream_manager.h +++ b/src/engine/stream_manager.h @@ -91,7 +91,7 @@ RunContext StreamManager::GetRunContext(Context const& ctx) if (event_pools_.at(ctx.dev_id) == nullptr) { event_pools_[ctx.dev_id] = std::make_unique(ctx); } - event_pool = event_pools_.at(ctx.dev_id).get(); + event_pool = event_pools_.at(ctx.dev_id).get(); use_counter = counter; counter = (counter + 1) % kStreams; } diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 2308777ab2fa..40d852b83b86 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -272,9 +272,7 @@ void ThreadedEngine::DeleteOperator(OprHandle op) { deps.insert(deps.end(), threaded_opr->const_vars.begin(), threaded_opr->const_vars.end()); deps.insert(deps.end(), threaded_opr->mutable_vars.begin(), threaded_opr->mutable_vars.end()); this->PushAsync( - [threaded_opr](RunContext, - CallbackOnStart on_start, - CallbackOnComplete on_complete) { + [threaded_opr](RunContext, CallbackOnStart on_start, CallbackOnComplete on_complete) { on_start(); ThreadedOpr::Delete(threaded_opr); on_complete(); @@ -352,9 +350,7 @@ void ThreadedEngine::PushSync(SyncFn exec_fn, const char* opr_name) { if (!bulk_size() || prop != FnProperty::kNormal || priority) { this->PushAsync( - [exec_fn](RunContext ctx, - CallbackOnStart on_start, - CallbackOnComplete on_complete) { + [exec_fn](RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) { on_start(); exec_fn(ctx); on_complete(); @@ -377,9 +373,8 @@ void ThreadedEngine::PushSync(SyncFn exec_fn, void ThreadedEngine::DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) { ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var); this->PushAsync( - [delete_fn, threaded_var](RunContext ctx, - CallbackOnStart on_start, - CallbackOnComplete on_complete) { + [delete_fn, threaded_var]( + RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) { // Mark variable as orphan, // so during `ThreadedEngine::OnComplete` it could be recycled. 
on_start(); @@ -408,9 +403,7 @@ void ThreadedEngine::WaitForVar(VarHandle var) { } std::atomic done{false}; this->PushAsync( - [this, &done](RunContext, - CallbackOnStart on_start, - CallbackOnComplete on_complete) { + [this, &done](RunContext, CallbackOnStart on_start, CallbackOnComplete on_complete) { on_start(); if (engine_info_) { LOG(INFO) << "Sync is executed"; @@ -553,15 +546,13 @@ void ThreadedEngine::OnCompleteStatic(Engine* engine, void* opr_block_, const dm OprBlock::Delete(opr_block); } -void ThreadedEngine::OnStartStatic(Engine *engine, void *opr_block, - const dmlc::Error* error) { +void ThreadedEngine::OnStartStatic(Engine* engine, void* opr_block, const dmlc::Error* error) { // no-op } #if MXNET_USE_CUDA -static inline void AddEventHelper( - std::unordered_map* events_per_stream, - const EventInfo& cuda_event) { +static inline void AddEventHelper(std::unordered_map* events_per_stream, + const EventInfo& cuda_event) { auto event_stream = cuda_event.stream; if (events_per_stream->count(event_stream) > 0) { if ((*events_per_stream)[event_stream].pool_index < cuda_event.pool_index) { @@ -579,23 +570,22 @@ static inline bool IsEngineAsync() { return tag_pos != std::string::npos; } -void ThreadedEngine::OnStartCPU(Engine *engine, void *opr_block, - const dmlc::Error* error) { +void ThreadedEngine::OnStartCPU(Engine* engine, void* opr_block, const dmlc::Error* error) { static bool use_new_dep_engine = IsEngineAsync(); if (!use_new_dep_engine) { return; } - ThreadedOpr *threaded_opr = static_cast(opr_block)->opr; + ThreadedOpr* threaded_opr = static_cast(opr_block)->opr; std::unordered_map event_per_stream; for (auto* read_var : threaded_opr->const_vars) { - auto &sync_obj = read_var->sync_object; + auto& sync_obj = read_var->sync_object; std::lock_guard l(sync_obj.mutex); - auto &reader_events = sync_obj.reader_events; + auto& reader_events = sync_obj.reader_events; // check for expired events and delete them - reader_events.erase(std::remove_if(reader_events.begin(), reader_events.end(), - [&](const EventInfo e_i) { - return e_i.event.expired(); - }), reader_events.end()); + reader_events.erase(std::remove_if(reader_events.begin(), + reader_events.end(), + [&](const EventInfo e_i) { return e_i.event.expired(); }), + reader_events.end()); for (auto& cuda_event : reader_events) { AddEventHelper(&event_per_stream, cuda_event); } @@ -609,16 +599,16 @@ void ThreadedEngine::OnStartCPU(Engine *engine, void *opr_block, } for (auto* write_var : threaded_opr->mutable_vars) { - auto &sync_obj = write_var->sync_object; + auto& sync_obj = write_var->sync_object; std::lock_guard l(sync_obj.mutex); - auto &reader_events = sync_obj.reader_events; + auto& reader_events = sync_obj.reader_events; // check for expired events and delete them - reader_events.erase(std::remove_if(reader_events.begin(), reader_events.end(), - [&](const EventInfo e_i) { - return e_i.event.expired(); - }), reader_events.end()); + reader_events.erase(std::remove_if(reader_events.begin(), + reader_events.end(), + [&](const EventInfo e_i) { return e_i.event.expired(); }), + reader_events.end()); for (auto& cuda_event : reader_events) { - AddEventHelper(&event_per_stream, cuda_event); + AddEventHelper(&event_per_stream, cuda_event); } if (!sync_obj.writer_event.empty()) { if (sync_obj.writer_event[0].event.expired()) { @@ -634,26 +624,25 @@ void ThreadedEngine::OnStartCPU(Engine *engine, void *opr_block, } } -void ThreadedEngine::OnStartGPU(Engine *engine, void *sync_info, - const dmlc::Error* error) { +void 
ThreadedEngine::OnStartGPU(Engine* engine, void* sync_info, const dmlc::Error* error) { static bool use_new_dep_engine = IsEngineAsync(); if (!use_new_dep_engine) { return; } - auto *info = reinterpret_cast(sync_info); + auto* info = reinterpret_cast(sync_info); CHECK(info->stream != nullptr); - auto *worker_stream = reinterpret_cast *>(info->stream); - ThreadedOpr *threaded_opr = static_cast(info->opr_block)->opr; + auto* worker_stream = reinterpret_cast*>(info->stream); + ThreadedOpr* threaded_opr = static_cast(info->opr_block)->opr; std::unordered_map event_per_stream; for (auto* read_var : threaded_opr->const_vars) { - auto &sync_obj = read_var->sync_object; + auto& sync_obj = read_var->sync_object; std::lock_guard l(sync_obj.mutex); - auto &reader_events = sync_obj.reader_events; + auto& reader_events = sync_obj.reader_events; // check for expired events and delete them - reader_events.erase(std::remove_if(reader_events.begin(), reader_events.end(), - [&](const EventInfo e_i) { - return e_i.event.expired(); - }), reader_events.end()); + reader_events.erase(std::remove_if(reader_events.begin(), + reader_events.end(), + [&](const EventInfo e_i) { return e_i.event.expired(); }), + reader_events.end()); for (auto& writer : sync_obj.writer_event) { if (writer.event.expired()) { sync_obj.writer_event.clear(); @@ -671,21 +660,20 @@ void ThreadedEngine::OnStartGPU(Engine *engine, void *sync_info, } } if (!found) { - AddEventHelper(&event_per_stream, - writer); + AddEventHelper(&event_per_stream, writer); } } } } for (auto* write_var : threaded_opr->mutable_vars) { - auto &sync_obj = write_var->sync_object; + auto& sync_obj = write_var->sync_object; std::lock_guard l(sync_obj.mutex); // check for expired events and delete them - auto &reader_events = sync_obj.reader_events; - reader_events.erase(std::remove_if(reader_events.begin(), reader_events.end(), - [&](const EventInfo e_i) { - return e_i.event.expired(); - }), reader_events.end()); + auto& reader_events = sync_obj.reader_events; + reader_events.erase(std::remove_if(reader_events.begin(), + reader_events.end(), + [&](const EventInfo e_i) { return e_i.event.expired(); }), + reader_events.end()); // if there are some readers, we wait for them for (auto& cuda_event : reader_events) { if (worker_stream->stream_ != cuda_event.stream) { @@ -708,12 +696,11 @@ void ThreadedEngine::OnStartGPU(Engine *engine, void *sync_info, } } -void ThreadedEngine::OnCompleteGPU(Engine *engine, void *sync_info, - const dmlc::Error* error) { - auto *info = reinterpret_cast(sync_info); +void ThreadedEngine::OnCompleteGPU(Engine* engine, void* sync_info, const dmlc::Error* error) { + auto* info = reinterpret_cast(sync_info); CHECK(info->stream != nullptr); - auto *worker_stream = reinterpret_cast *>(info->stream); + auto* worker_stream = reinterpret_cast*>(info->stream); static bool use_new_dep_engine = IsEngineAsync(); if (!use_new_dep_engine) { @@ -723,13 +710,13 @@ void ThreadedEngine::OnCompleteGPU(Engine *engine, void *sync_info, return; } - ThreadedOpr *threaded_opr = static_cast(info->opr_block)->opr; - auto* event_pool = static_cast(info->event_pool); - auto[event, event_pool_idx] = event_pool->GetNextEvent(); - auto ev = event.lock(); + ThreadedOpr* threaded_opr = static_cast(info->opr_block)->opr; + auto* event_pool = static_cast(info->event_pool); + auto [event, event_pool_idx] = event_pool->GetNextEvent(); + auto ev = event.lock(); MSHADOW_CUDA_CALL(cudaEventRecord(*ev, worker_stream->stream_)); for (auto* read_var : threaded_opr->const_vars) { - auto 
&sync_obj = read_var->sync_object; + auto& sync_obj = read_var->sync_object; std::lock_guard l(sync_obj.mutex); // If some reader event is already recorded on the same stream, // we want to replace ourselves by it @@ -737,7 +724,7 @@ void ThreadedEngine::OnCompleteGPU(Engine *engine, void *sync_info, for (i = 0; i < sync_obj.reader_events.size(); ++i) { auto stream = sync_obj.reader_events[i].stream; if (stream == worker_stream->stream_) { - sync_obj.reader_events[i].event = event; + sync_obj.reader_events[i].event = event; sync_obj.reader_events[i].pool_index = event_pool_idx; break; } @@ -748,7 +735,7 @@ void ThreadedEngine::OnCompleteGPU(Engine *engine, void *sync_info, } for (auto* write_var : threaded_opr->mutable_vars) { - auto &sync_obj = write_var->sync_object; + auto& sync_obj = write_var->sync_object; std::lock_guard l(sync_obj.mutex); sync_obj.reader_events.clear(); sync_obj.writer_event.clear(); @@ -760,6 +747,5 @@ void ThreadedEngine::OnCompleteGPU(Engine *engine, void *sync_info, } #endif - } // namespace engine } // namespace mxnet diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index 1848e42a4e00..a9e08a80aadc 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -353,8 +353,10 @@ class ThreadedEngine : public Engine { * \param run_ctx runtime context used to execute the function. * \param opr_block the opr_block to be executed and deleted. */ - void ExecuteOprBlock(RunContext run_ctx, OprBlock* opr_block, - CallbackOnStart on_start, CallbackOnComplete callback) { + void ExecuteOprBlock(RunContext run_ctx, + OprBlock* opr_block, + CallbackOnStart on_start, + CallbackOnComplete callback) { ThreadedOpr* threaded_opr = opr_block->opr; if (opr_block->profiling && threaded_opr->opr_name.size()) { std::unique_ptr attrs; @@ -433,24 +435,19 @@ class ThreadedEngine : public Engine { } protected: - static void OnStartStatic(Engine *engine, void *opr_block, - const dmlc::Error* error); - static void OnCompleteStatic(Engine *engine, void *threaded_opr, - const dmlc::Error* error); + static void OnStartStatic(Engine* engine, void* opr_block, const dmlc::Error* error); + static void OnCompleteStatic(Engine* engine, void* threaded_opr, const dmlc::Error* error); #if MXNET_USE_CUDA - static void OnStartCPU(Engine *engine, void *opr_block, - const dmlc::Error* error); - static void OnStartGPU(Engine *engine, void *sync_info, - const dmlc::Error* error); - static void OnCompleteGPU(Engine *engine, void *sync_info, - const dmlc::Error* error); + static void OnStartCPU(Engine* engine, void* opr_block, const dmlc::Error* error); + static void OnStartGPU(Engine* engine, void* sync_info, const dmlc::Error* error); + static void OnCompleteGPU(Engine* engine, void* sync_info, const dmlc::Error* error); struct GPUWorkerSyncInfo : public common::ObjectPoolAllocatable { - void *opr_block{nullptr}; - void *stream{nullptr}; - void *event_pool{nullptr}; + void* opr_block{nullptr}; + void* stream{nullptr}; + void* event_pool{nullptr}; }; - std::shared_ptr > objpool_gpu_sync_ref_; + std::shared_ptr> objpool_gpu_sync_ref_; #endif private: @@ -559,9 +556,7 @@ class ThreadedEngine : public Engine { DeduplicateVarHandle(&bulk_status.const_vars, &bulk_status.mutable_vars); auto functions = bulk_status.functions; this->PushAsync( - [functions](RunContext ctx, - CallbackOnStart on_start, - CallbackOnComplete on_complete) { + [functions](RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) { on_start(); for (auto& fn : *functions) { fn(ctx); 
diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index 15aa60073ca1..33408a7bffed 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -122,10 +122,9 @@ class ThreadedEnginePerDevice : public ThreadedEngine { MSHADOW_CATCH_ERROR(mshadow::SetDevice(ctx.dev_id)); #endif } - CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartStatic, - opr_block); - CallbackOnComplete callback = this->CreateCallback(ThreadedEngine::OnCompleteStatic, - opr_block); + CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartStatic, opr_block); + CallbackOnComplete callback = + this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block); this->ExecuteOprBlock(RunContext{ctx, nullptr, nullptr}, opr_block, on_start, callback); } else { if (ctx.dev_mask() == Context::kCPU) { @@ -297,8 +296,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine { if (event_pool_it != cuda_event_pool_per_worker_.end()) { event_pool = event_pool_it->second.get(); } else { - auto res = cuda_event_pool_per_worker_.emplace(ctx.dev_id, - std::make_unique(ctx)); + auto res = + cuda_event_pool_per_worker_.emplace(ctx.dev_id, std::make_unique(ctx)); event_pool = res.first->second.get(); } // execute task @@ -324,11 +323,11 @@ class ThreadedEnginePerDevice : public ThreadedEngine { #if MXNET_USE_NVTX common::cuda::nvtx::gpuRangeStop(); #endif - auto* info = ThreadedEngine::GPUWorkerSyncInfo::New(); - info->opr_block = opr_block; - info->stream = stream; - info->event_pool = event_pool; - CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartGPU, info); + auto* info = ThreadedEngine::GPUWorkerSyncInfo::New(); + info->opr_block = opr_block; + info->stream = stream; + info->event_pool = event_pool; + CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartGPU, info); CallbackOnComplete callback = this->CreateCallback(ThreadedEngine::OnCompleteGPU, info); this->ExecuteOprBlock(run_ctx, opr_block, on_start, callback); } @@ -361,8 +360,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine { #else CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartStatic, opr_block); #endif - CallbackOnComplete callback = this->CreateCallback(ThreadedEngine::OnCompleteStatic, - opr_block); + CallbackOnComplete callback = + this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block); this->ExecuteOprBlock(run_ctx, opr_block, on_start, callback); } } diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc index 2d9183667ee4..0ec91b23e260 100644 --- a/src/engine/threaded_engine_pooled.cc +++ b/src/engine/threaded_engine_pooled.cc @@ -164,14 +164,14 @@ class ThreadedEnginePooled : public ThreadedEngine { callback = this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block); } else { CHECK_EQ(opr_block->ctx.dev_mask(), Context::kGPU); - auto stream = rctx.get_stream(); - auto event_pool = static_cast(rctx.event_pool); - auto* info = ThreadedEngine::GPUWorkerSyncInfo::New(); - info->opr_block = opr_block; - info->stream = stream; + auto stream = rctx.get_stream(); + auto event_pool = static_cast(rctx.event_pool); + auto* info = ThreadedEngine::GPUWorkerSyncInfo::New(); + info->opr_block = opr_block; + info->stream = stream; info->event_pool = event_pool; - on_start = this->CreateOnStart(ThreadedEngine::OnStartGPU, info); - callback = this->CreateCallback(ThreadedEngine::OnCompleteGPU, info); + on_start = 
this->CreateOnStart(ThreadedEngine::OnStartGPU, info); + callback = this->CreateCallback(ThreadedEngine::OnCompleteGPU, info); } #else // MXNET_USE_CUDA CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartStatic, opr_block); diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 8f71417851c5..96bff8e2597c 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -783,15 +783,16 @@ inline void PushOperator(const OpStatePtr& state, RunContext rctx{ctx, nullptr, nullptr}; run(rctx, engine::CallbackOnStart(), engine::CallbackOnComplete()); } else if (exec_type == ExecType::kSync) { - Engine::Get()->PushSync([=](RunContext rctx) { run(rctx, - engine::CallbackOnStart(), - engine::CallbackOnComplete()); }, - ctx, - read_vars, - write_vars, - FnProperty::kNormal, - 0, - op->name.c_str()); + Engine::Get()->PushSync( + [=](RunContext rctx) { + run(rctx, engine::CallbackOnStart(), engine::CallbackOnComplete()); + }, + ctx, + read_vars, + write_vars, + FnProperty::kNormal, + 0, + op->name.c_str()); } else { CHECK(exec_type == ExecType::kAsync); Engine::Get()->PushAsync( @@ -844,15 +845,16 @@ inline void PushOperator(const OpStatePtr& state, RunContext rctx{ctx, nullptr}; run(rctx, engine::CallbackOnStart(), engine::CallbackOnComplete()); } else if (exec_type == ExecType::kSync) { - Engine::Get()->PushSync([=](RunContext rctx) { run(rctx, - engine::CallbackOnStart(), - engine::CallbackOnComplete()); }, - ctx, - read_vars, - write_vars, - FnProperty::kNormal, - 0, - op->name.c_str()); + Engine::Get()->PushSync( + [=](RunContext rctx) { + run(rctx, engine::CallbackOnStart(), engine::CallbackOnComplete()); + }, + ctx, + read_vars, + write_vars, + FnProperty::kNormal, + 0, + op->name.c_str()); } else { CHECK(exec_type == ExecType::kAsync); Engine::Get()->PushAsync( diff --git a/src/kvstore/p3store_dist.h b/src/kvstore/p3store_dist.h index ed3875f96b45..56912cd7abcf 100644 --- a/src/kvstore/p3store_dist.h +++ b/src/kvstore/p3store_dist.h @@ -77,10 +77,10 @@ class P3StoreDist : public KVStoreDist { LOG(FATAL) << "NotImplementedError: PushCompressed not implemented in P3StoreDist."; } - void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) override { - auto push_to_servers = [this, key, pskv, send_buf, priority] (RunContext rctx, - Engine::CallbackOnStart on_start, - Engine::CallbackOnComplete cb) { + void PushDefault(int key, const NDArray& send_buf, const PSKV& pskv, int priority) override { + auto push_to_servers = [this, key, pskv, send_buf, priority](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete cb) { on_start(); const int dtype = send_buf.dtype(); // convert to ps keys diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 9d0703e9b66c..f67aa906fd95 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -136,7 +136,7 @@ NDArray::Chunk::~Chunk() { engine->DeleteVariable( [mem, skip_free, var = this->var](RunContext s) mutable { #if MXNET_USE_CUDA - auto &sync_obj = var->sync_object; + auto& sync_obj = var->sync_object; Storage::SyncObj storage_sync_obj; { std::lock_guard l(sync_obj.mutex); @@ -761,19 +761,19 @@ void NDArray::Reorder2DefaultAsync() const { std::vector mutable_vars(1, this->var()); NDArray tmp = *this; Engine::Get()->PushAsync( - [tmp](RunContext ctx, - Engine::CallbackOnStart on_start, - Engine::CallbackOnComplete on_complete) { - on_start(); - tmp.ptr_->Reorder2Default(); - on_complete(); - }, - ctx(), - const_vars, 
- mutable_vars, - FnProperty::kNormal, - 0, - "Reorder2Default"); + [tmp](RunContext ctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); + tmp.ptr_->Reorder2Default(); + on_complete(); + }, + ctx(), + const_vars, + mutable_vars, + FnProperty::kNormal, + 0, + "Reorder2Default"); } // now just support bf16->fp32 @@ -797,22 +797,22 @@ void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::desc& desc) const { const auto version = this->version(); Engine::Get()->PushAsync( [tmp, version, desc](RunContext ctx, - Engine::CallbackOnStart on_start, - Engine::CallbackOnComplete on_complete) { - on_start(); - // MXNet will try to reuse NDArray from memory planning, so we need to ensure - // the NDArray is still holding the original trunk data. - if (tmp.version() == version) { - tmp.ptr_->MKLDNNDataReorder(desc); - } - on_complete(); - }, - ctx(), - const_vars, - mutable_vars, - FnProperty::kNormal, - 0, - "Reorder"); + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); + // MXNet will try to reuse NDArray from memory planning, so we need to ensure + // the NDArray is still holding the original trunk data. + if (tmp.version() == version) { + tmp.ptr_->MKLDNNDataReorder(desc); + } + on_complete(); + }, + ctx(), + const_vars, + mutable_vars, + FnProperty::kNormal, + 0, + "Reorder"); } const mkldnn::memory* NDArray::GetMKLDNNData() const { @@ -2204,8 +2204,7 @@ void NDArray::SyncCopyFromCPU(const void* data, size_t size) const { Engine::CallbackOnComplete on_complete) { on_start(); TBlob dst = this->data(); - ndarray::Copy(src, &dst, - Context::CPU(), this->ctx(), rctx); + ndarray::Copy(src, &dst, Context::CPU(), this->ctx(), rctx); on_complete(); }, this->ctx(), @@ -2286,7 +2285,7 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) { Engine::CallbackOnComplete on_complete) { on_start(); const TBlob src_data = (i >= 0 ? src.aux_data(i) : src.data()); - TBlob dst_data = get_dst_data(src_data.shape_); + TBlob dst_data = get_dst_data(src_data.shape_); ndarray::Copy(src_data, &dst_data, src.ctx(), this->ctx(), rctx); on_complete(); }, @@ -2303,7 +2302,7 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) { Engine::CallbackOnComplete on_complete) { on_start(); const TBlob src_data = (i >= 0 ? src.aux_data(i) : src.data()); - TBlob dst_data = get_dst_data(src_data.shape_); + TBlob dst_data = get_dst_data(src_data.shape_); ndarray::Copy(src_data, &dst_data, src.ctx(), this->ctx(), rctx); on_complete(); }, @@ -2320,7 +2319,7 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) { Engine::CallbackOnComplete on_complete) { on_start(); const TBlob src_data = (i >= 0 ? src.aux_data(i) : src.data()); - TBlob dst_data = get_dst_data(src_data.shape_); + TBlob dst_data = get_dst_data(src_data.shape_); ndarray::Copy(src_data, &dst_data, src.ctx(), this->ctx(), rctx); on_complete(); }, @@ -2380,13 +2379,13 @@ void NDArray::SyncCopyToCPU(void* data, size_t size) const { Engine::CallbackOnComplete on_complete) { on_start(); { - auto var = this->var(); + auto var = this->var(); auto& sync_obj = var->sync_object; std::lock_guard lock{sync_obj.mutex}; bool has_writer = false; std::shared_ptr w_ev_ptr; if (!sync_obj.writer_event.empty()) { - w_ev_ptr = sync_obj.writer_event[0].event.lock(); + w_ev_ptr = sync_obj.writer_event[0].event.lock(); has_writer = w_ev_ptr ? 
true : false; } for (auto ev : sync_obj.reader_events) { @@ -2398,7 +2397,7 @@ void NDArray::SyncCopyToCPU(void* data, size_t size) const { if (has_writer) { auto w_ev = sync_obj.writer_event[0]; if (w_ev.stream == ev.stream) { - event = w_ev.pool_index > ev.pool_index ? *w_ev_ptr : *event_ptr; + event = w_ev.pool_index > ev.pool_index ? *w_ev_ptr : *event_ptr; has_writer = false; } } @@ -2408,8 +2407,7 @@ void NDArray::SyncCopyToCPU(void* data, size_t size) const { CUDA_CALL(cudaEventSynchronize(*w_ev_ptr)); } } - ndarray::Copy(this->data(), &dst, - this->ctx(), Context::CPU(), rctx); + ndarray::Copy(this->data(), &dst, this->ctx(), Context::CPU(), rctx); on_complete(); }, this->ctx(), diff --git a/src/operator/operator_util.cc b/src/operator/operator_util.cc index 8235827a91ee..b2277b3d00ac 100644 --- a/src/operator/operator_util.cc +++ b/src/operator/operator_util.cc @@ -486,16 +486,17 @@ void SimpleOpRegEntryImpl::RegisterSourceImperative() { SourceFunction fun = fsource_[dev_mask]; OpReqType req = kWriteTo; - Engine::Get()->PushSync([ret, fun, dev_mask, req, env](RunContext ctx) { - TBlob tmp = ret.data(); - (*fun)(env, &tmp, req, ctx); - }, - ret.ctx(), - {}, - write_vars, - FnProperty::kNormal, - 0, - "RegisterSourceImperative"); + Engine::Get()->PushSync( + [ret, fun, dev_mask, req, env](RunContext ctx) { + TBlob tmp = ret.data(); + (*fun)(env, &tmp, req, ctx); + }, + ret.ctx(), + {}, + write_vars, + FnProperty::kNormal, + 0, + "RegisterSourceImperative"); }; // register the function. NDArrayReg().set_body(body).set_num_use_vars(0).set_num_mutate_vars(1); @@ -662,16 +663,17 @@ void SimpleOpRegEntryImpl::RegisterUnaryImperative() { << "inplace operation is not enabled for operator " << name; } - Engine::Get()->PushSync([src, ret, fun, dev_mask, req, env](RunContext ctx) { - TBlob tmp = ret.data(); - (*fun)(src.data(), env, &tmp, req, ctx); - }, - src.ctx(), - const_vars, - write_vars, - FnProperty::kNormal, - 0, - "RegisterUnaryImperative"); + Engine::Get()->PushSync( + [src, ret, fun, dev_mask, req, env](RunContext ctx) { + TBlob tmp = ret.data(); + (*fun)(src.data(), env, &tmp, req, ctx); + }, + src.ctx(), + const_vars, + write_vars, + FnProperty::kNormal, + 0, + "RegisterUnaryImperative"); }; // register the function. NDArrayReg().set_body(body).set_num_use_vars(1).set_num_mutate_vars(1); @@ -938,16 +940,17 @@ void SimpleOpRegEntryImpl::RegisterBinaryImperative() { << " warning, perform inplace operation with right operand, may not be supported"; } - Engine::Get()->PushSync([lhs, rhs, ret, fun, dev_mask, req, env](RunContext ctx) { - TBlob tmp = ret.data(); - (*fun)(lhs.data(), rhs.data(), env, &tmp, req, ctx); - }, - lhs.ctx(), - const_vars, - write_vars, - FnProperty::kNormal, - 0, - "RegisterBinaryImperative"); + Engine::Get()->PushSync( + [lhs, rhs, ret, fun, dev_mask, req, env](RunContext ctx) { + TBlob tmp = ret.data(); + (*fun)(lhs.data(), rhs.data(), env, &tmp, req, ctx); + }, + lhs.ctx(), + const_vars, + write_vars, + FnProperty::kNormal, + 0, + "RegisterBinaryImperative"); }; // register the function. 
NDArrayReg().set_body(body).set_num_use_vars(2).set_num_mutate_vars(1); diff --git a/src/resource.cc b/src/resource.cc index 7c434c80abf0..010481f322d3 100644 --- a/src/resource.cc +++ b/src/resource.cc @@ -266,20 +266,20 @@ class ResourceManagerImpl : public ResourceManager { inline void Seed(uint32_t seed) { mshadow::Random* r = prnd; Engine::Get()->PushAsync( - [r, seed](RunContext rctx, - Engine::CallbackOnStart on_start, - Engine::CallbackOnComplete on_complete) { - on_start(); - r->set_stream(rctx.get_stream()); - r->Seed(seed); - on_complete(); - }, - ctx, - {}, - {resource.var}, - FnProperty::kNormal, - 0, - "ResourceRandomSetSeed"); + [r, seed](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); + r->set_stream(rctx.get_stream()); + r->Seed(seed); + on_complete(); + }, + ctx, + {}, + {resource.var}, + FnProperty::kNormal, + 0, + "ResourceRandomSetSeed"); } }; @@ -344,41 +344,42 @@ class ResourceManagerImpl : public ResourceManager { uint32_t current_seed = p->ctx.dev_id + i * kMaxNumGPUs + seed * kRandMagic; Resource* r = &(p->resource[i]); Engine::Get()->PushAsync( - [r, current_seed](RunContext rctx, - Engine::CallbackOnStart on_start, - Engine::CallbackOnComplete on_complete) { - on_start(); - auto state_space = static_cast(r->ptr_); - mshadow::Stream* stream = rctx.get_stream(); - CHECK_EQ(state_space->ctx.dev_id, stream->dev_id) - << "The device id of cudnn dropout state space doesn't match that from stream."; - if (!state_space->handle.size) { - // not allocated yet - size_t dropout_state_size; - CUDNN_CALL(cudnnDropoutGetStatesSize(stream->dnn_handle_, &dropout_state_size)); - // reserve GPU space - Storage::Get()->DirectFree( - Storage::Get()->Alloc(dropout_state_size, state_space->ctx)); - state_space->GetSpace(dropout_state_size, "cudnn_dropout_state"); - } - cudnnDropoutDescriptor_t temp_descriptor; - CUDNN_CALL(cudnnCreateDropoutDescriptor(&temp_descriptor)); - CUDNN_CALL(cudnnSetDropoutDescriptor(temp_descriptor, stream->dnn_handle_, - 0.5, - state_space->handle.dptr, - state_space->handle.size, - current_seed)); - CUDNN_CALL(cudnnDestroyDropoutDescriptor(temp_descriptor)); - cudaStream_t cuda_stream = mshadow::Stream::GetStream(stream); - cudaStreamSynchronize(cuda_stream); - on_complete(); - }, - p->ctx, - {}, - {r->var}, - FnProperty::kNormal, - 0, - "CUDNNDropoutDescriptorSeed"); + [r, current_seed](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); + auto state_space = static_cast(r->ptr_); + mshadow::Stream* stream = rctx.get_stream(); + CHECK_EQ(state_space->ctx.dev_id, stream->dev_id) + << "The device id of cudnn dropout state space doesn't match that from stream."; + if (!state_space->handle.size) { + // not allocated yet + size_t dropout_state_size; + CUDNN_CALL(cudnnDropoutGetStatesSize(stream->dnn_handle_, &dropout_state_size)); + // reserve GPU space + Storage::Get()->DirectFree( + Storage::Get()->Alloc(dropout_state_size, state_space->ctx)); + state_space->GetSpace(dropout_state_size, "cudnn_dropout_state"); + } + cudnnDropoutDescriptor_t temp_descriptor; + CUDNN_CALL(cudnnCreateDropoutDescriptor(&temp_descriptor)); + CUDNN_CALL(cudnnSetDropoutDescriptor(temp_descriptor, + stream->dnn_handle_, + 0.5, + state_space->handle.dptr, + state_space->handle.size, + current_seed)); + CUDNN_CALL(cudnnDestroyDropoutDescriptor(temp_descriptor)); + cudaStream_t cuda_stream = mshadow::Stream::GetStream(stream); + cudaStreamSynchronize(cuda_stream); + 
on_complete(); + }, + p->ctx, + {}, + {r->var}, + FnProperty::kNormal, + 0, + "CUDNNDropoutDescriptorSeed"); } p->curr_ptr.store(0); @@ -453,19 +454,19 @@ class ResourceManagerImpl : public ResourceManager { inline void SeedOne(size_t i, uint32_t seed) { common::random::RandGenerator* r = sampler[i]; Engine::Get()->PushAsync( - [r, seed](RunContext rctx, - Engine::CallbackOnStart on_start, - Engine::CallbackOnComplete on_complete) { - on_start(); - r->Seed(rctx.get_stream(), seed); - on_complete(); - }, - ctx, - {}, - {resource[i].var}, - FnProperty::kNormal, - 0, - "ResourceNativeRandomSetSeed"); + [r, seed](RunContext rctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) { + on_start(); + r->Seed(rctx.get_stream(), seed); + on_complete(); + }, + ctx, + {}, + {resource[i].var}, + FnProperty::kNormal, + 0, + "ResourceNativeRandomSetSeed"); } // get next resource in round roubin matter inline Resource GetNext() { diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h index fff4549e837b..0afff3241f43 100644 --- a/src/storage/pooled_storage_manager.h +++ b/src/storage/pooled_storage_manager.h @@ -130,9 +130,8 @@ class PooledStorageManager : public StorageManager, public BucketingStrategy, pu void Free(Storage::Handle handle) override { // Insert returned memory in cache std::lock_guard lock(Storage::Get()->GetMutex(dev_type_)); - StoringMethod::InsertInCache(BucketingStrategy::get_bucket(handle.size), - handle.dptr, - handle.sync_obj); + StoringMethod::InsertInCache( + BucketingStrategy::get_bucket(handle.size), handle.dptr, handle.sync_obj); } void DirectFree(Storage::Handle handle) override { @@ -208,7 +207,7 @@ void PooledStorageManager::Alloc(Storage::Hand } else { // Reusing memory auto ptr_syncobj = reuse_pool->back(); - handle->dptr = ptr_syncobj.first; + handle->dptr = ptr_syncobj.first; if (dev_type_ == Context::kGPU) { handle->sync_obj = ptr_syncobj.second; #if MXNET_USE_CUDA diff --git a/src/storage/storage.cc b/src/storage/storage.cc index 90aa302b1476..d11fde26a624 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -256,7 +256,7 @@ const std::string env_var_name(const char* dev_type, env_var_type type) { } // namespace storage -const std::shared_ptr &Storage::_GetSharedRef() { +const std::shared_ptr& Storage::_GetSharedRef() { #ifdef __MXNET_JS__ // dummy code needed for emscripten code to pass // do not know why, the new will be NULLPTR diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc index 49a5abaed6ed..465e387b8d42 100644 --- a/tests/cpp/engine/threaded_engine_test.cc +++ b/tests/cpp/engine/threaded_engine_test.cc @@ -113,7 +113,9 @@ double EvaluateWorkloads(const std::vector& workloads, auto func = [wl, data](RunContext ctx, Engine::CallbackOnStart on_start, Engine::CallbackOnComplete cb) { - on_start(); EvaluateWorkload(wl, data); cb(); + on_start(); + EvaluateWorkload(wl, data); + cb(); }; std::vector reads; for (auto i : wl.reads) { @@ -356,7 +358,8 @@ TEST(Engine, basics) { std::this_thread::sleep_for(std::chrono::seconds{1}); cb(); }, - {var}, {})); + {var}, + {})); engine->Push(oprs.at(i), mxnet::Context{}); } engine->WaitForAll(); @@ -381,7 +384,8 @@ TEST(Engine, basics) { std::this_thread::sleep_for(std::chrono::milliseconds{500}); cb(); }, - {}, {var})); + {}, + {var})); engine->Push(oprs.at(i), mxnet::Context{}); } // std::this_thread::sleep_for(std::chrono::seconds{1}); @@ -402,7 +406,7 @@ TEST(Engine, basics) { var = 
engine->NewVariable(); oprs.clear(); oprs.push_back(engine->NewOperator( - [](mxnet::RunContext ctx, + [](mxnet::RunContext ctx, mxnet::Engine::CallbackOnStart on_start, mxnet::Engine::CallbackOnComplete cb) { std::this_thread::sleep_for(std::chrono::seconds{2}); @@ -410,7 +414,9 @@ TEST(Engine, basics) { Foo(ctx, 42); cb(); }, - {}, {var}, mxnet::FnProperty::kCopyFromGPU)); + {}, + {var}, + mxnet::FnProperty::kCopyFromGPU)); engine->Push(oprs.at(0), mxnet::Context{}); LOG(INFO) << "IO operator pushed, should wait for 2 seconds."; engine->WaitForVar(var); @@ -425,15 +431,16 @@ TEST(Engine, basics) { var = engine->NewVariable(); oprs.clear(); oprs.push_back(engine->NewOperator( - [](mxnet::RunContext ctx, - mxnet::Engine::CallbackOnStart on_start, - mxnet::Engine::CallbackOnComplete cb) { + [](mxnet::RunContext ctx, + mxnet::Engine::CallbackOnStart on_start, + mxnet::Engine::CallbackOnComplete cb) { on_start(); Foo(ctx, 42); std::this_thread::sleep_for(std::chrono::seconds{2}); cb(); }, - {var}, {})); + {var}, + {})); engine->Push(oprs.at(0), mxnet::Context{}); LOG(INFO) << "Operator pushed, should not wait."; engine->WaitForVar(var); @@ -473,7 +480,8 @@ TEST(Engine, VarVersion) { Foo(ctx, i); cb(); }, - {var}, {})); + {var}, + {})); engine->Push(oprs.at(i), mxnet::Context{}); } engine->WaitForAll(); @@ -497,7 +505,8 @@ TEST(Engine, VarVersion) { Foo(ctx, i); cb(); }, - {}, {var})); + {}, + {var})); engine->Push(oprs.at(i), mxnet::Context{}); } engine->WaitForAll(); From b1743570170b5fb8a045f5d19138a2ce69c1c2f4 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Thu, 14 Oct 2021 10:25:04 +0900 Subject: [PATCH 8/8] Fix rebase errors Signed-off-by: Serge Panev --- src/c_api/c_api.cc | 2 +- src/engine/threaded_engine_perdevice.cc | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 45ae9580a61e..736f71622850 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -3806,7 +3806,7 @@ int MXEnginePushAsync(EngineAsyncFunc async_func, std::shared_ptr shared_func_param(func_param, deleter); exec_fn = [async_func, shared_func_param]( RunContext rctx, CallbackOnStart on_start, CallbackOnComplete on_complete) { - async_func(&rctx, &on_start, , &on_complete, shared_func_param.get()); + async_func(&rctx, &on_start, &on_complete, shared_func_param.get()); }; } diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index 33408a7bffed..b566e4417a41 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -318,10 +318,6 @@ class ThreadedEnginePerDevice : public ThreadedEngine { : nvtx_name.size(); auto color = common::cuda::nvtx::nameToColor(nvtx_name, name_prefix_len); common::cuda::nvtx::gpuRangeStart(color, nvtx_name); -#endif - this->ExecuteOprBlock(run_ctx, opr_block); -#if MXNET_USE_NVTX - common::cuda::nvtx::gpuRangeStop(); #endif auto* info = ThreadedEngine::GPUWorkerSyncInfo::New(); info->opr_block = opr_block; @@ -330,6 +326,9 @@ class ThreadedEnginePerDevice : public ThreadedEngine { CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartGPU, info); CallbackOnComplete callback = this->CreateCallback(ThreadedEngine::OnCompleteGPU, info); this->ExecuteOprBlock(run_ctx, opr_block, on_start, callback); +#if MXNET_USE_NVTX + common::cuda::nvtx::gpuRangeStop(); +#endif } #else ready_event->signal();
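
The OnStartGPU/OnCompleteGPU callbacks added by this series express GPU read/write dependencies with timing-disabled CUDA events instead of blocking the worker thread. Stripped of the engine plumbing (Var sync_object, CUDAEventPool, GPUWorkerSyncInfo), the underlying ordering scheme is the plain record/wait pattern sketched below. This is an illustrative sketch only, not part of the patch: the kernels, stream names, and buffer sizes are made up for the example; only the CUDA runtime calls correspond to what the engine uses.

#include <cuda_runtime.h>
#include <cstdio>

__global__ void producer_kernel(float* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] = 1.0f;           // the "writer" op
}

__global__ void consumer_kernel(float* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += 1.0f;          // the "reader" op, must see the writer's result
}

int main() {
  const int n = 1 << 20;
  float* data = nullptr;
  cudaMalloc(&data, n * sizeof(float));

  cudaStream_t writer_stream, reader_stream;
  cudaStreamCreateWithFlags(&writer_stream, cudaStreamNonBlocking);
  cudaStreamCreateWithFlags(&reader_stream, cudaStreamNonBlocking);

  // Timing is disabled, as in CUDAEvent, because the event is used purely for ordering.
  cudaEvent_t write_done;
  cudaEventCreateWithFlags(&write_done, cudaEventDisableTiming);

  // "OnCompleteGPU": record an event on the stream that produced the data.
  producer_kernel<<<(n + 255) / 256, 256, 0, writer_stream>>>(data, n);
  cudaEventRecord(write_done, writer_stream);

  // "OnStartGPU": a consumer running on a different stream waits on the writer's
  // event before its work may start; no host-side synchronization is involved.
  cudaStreamWaitEvent(reader_stream, write_done, 0);
  consumer_kernel<<<(n + 255) / 256, 256, 0, reader_stream>>>(data, n);

  cudaStreamSynchronize(reader_stream);
  printf("done\n");

  cudaEventDestroy(write_done);
  cudaStreamDestroy(writer_stream);
  cudaStreamDestroy(reader_stream);
  cudaFree(data);
  return 0;
}

In the engine itself the same two steps are split across operators: completion of one op records the event into the Var's sync_object, and the start callback of a later op issues cudaStreamWaitEvent on its own worker stream for every unexpired reader/writer event it finds there.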