From 04529efed409b4e04cb9e86d0988ab754ef994f0 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 13 Sep 2021 11:09:48 -0700 Subject: [PATCH 1/2] fix race condition in limiting resource adapter --- .../mr/device/limiting_resource_adaptor.hpp | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/include/rmm/mr/device/limiting_resource_adaptor.hpp b/include/rmm/mr/device/limiting_resource_adaptor.hpp index 5002962d5..6f00b937f 100644 --- a/include/rmm/mr/device/limiting_resource_adaptor.hpp +++ b/include/rmm/mr/device/limiting_resource_adaptor.hpp @@ -21,8 +21,7 @@ #include -namespace rmm { -namespace mr { +namespace rmm::mr { /** * @brief Resource that uses `Upstream` to allocate memory and limits the total * allocations possible. @@ -59,12 +58,12 @@ class limiting_resource_adaptor final : public device_memory_resource { RMM_EXPECTS(nullptr != upstream, "Unexpected null upstream resource pointer."); } - limiting_resource_adaptor() = delete; - ~limiting_resource_adaptor() = default; - limiting_resource_adaptor(limiting_resource_adaptor const&) = delete; - limiting_resource_adaptor(limiting_resource_adaptor&&) = default; + limiting_resource_adaptor() = delete; + ~limiting_resource_adaptor() override = default; + limiting_resource_adaptor(limiting_resource_adaptor const&) = delete; + limiting_resource_adaptor(limiting_resource_adaptor&&) noexcept = default; limiting_resource_adaptor& operator=(limiting_resource_adaptor const&) = delete; - limiting_resource_adaptor& operator=(limiting_resource_adaptor&&) = default; + limiting_resource_adaptor& operator=(limiting_resource_adaptor&&) noexcept = default; /** * @brief Return pointer to the upstream resource. @@ -79,14 +78,17 @@ class limiting_resource_adaptor final : public device_memory_resource { * @return true The upstream resource supports streams * @return false The upstream resource does not support streams. */ - bool supports_streams() const noexcept override { return upstream_->supports_streams(); } + [[nodiscard]] bool supports_streams() const noexcept override + { + return upstream_->supports_streams(); + } /** * @brief Query whether the resource supports the get_mem_info API. * * @return bool true if the upstream resource supports get_mem_info, false otherwise. */ - bool supports_get_mem_info() const noexcept override + [[nodiscard]] bool supports_get_mem_info() const noexcept override { return upstream_->supports_get_mem_info(); } @@ -100,7 +102,7 @@ class limiting_resource_adaptor final : public device_memory_resource { * @return std::size_t number of bytes that have been allocated through this * allocator. */ - std::size_t get_allocated_bytes() const { return allocated_bytes_; } + [[nodiscard]] std::size_t get_allocated_bytes() const { return allocated_bytes_; } /** * @brief Query the maximum number of bytes that this allocator is allowed @@ -109,7 +111,7 @@ class limiting_resource_adaptor final : public device_memory_resource { * * @return std::size_t max number of bytes allowed for this allocator */ - std::size_t get_allocation_limit() const { return allocation_limit_; } + [[nodiscard]] std::size_t get_allocation_limit() const { return allocation_limit_; } private: /** @@ -130,11 +132,12 @@ class limiting_resource_adaptor final : public device_memory_resource { void* p = nullptr; std::size_t proposed_size = rmm::detail::align_up(bytes, allocation_alignment_); - if (proposed_size + allocated_bytes_ <= allocation_limit_) { + allocated_bytes_ += proposed_size; + if (allocated_bytes_ <= allocation_limit_) { p = upstream_->allocate(bytes, stream); - allocated_bytes_ += proposed_size; } else { - throw rmm::bad_alloc{"Exceeded memory limit"}; + allocated_bytes_ -= proposed_size; + RMM_FAIL("Exceeded memory limit", rmm::bad_alloc); } return p; @@ -165,13 +168,12 @@ class limiting_resource_adaptor final : public device_memory_resource { * @return true If the two resources are equivalent * @return false If the two resources are not equal */ - bool do_is_equal(device_memory_resource const& other) const noexcept override + [[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override { if (this == &other) return true; else { - limiting_resource_adaptor const* cast = - dynamic_cast const*>(&other); + auto const* cast = dynamic_cast const*>(&other); if (cast != nullptr) return upstream_->is_equal(*cast->get_upstream()); else @@ -187,7 +189,8 @@ class limiting_resource_adaptor final : public device_memory_resource { * @param stream Stream on which to get the mem info. * @return std::pair contaiing free_size and total_size of memory */ - std::pair do_get_mem_info(cuda_stream_view stream) const override + [[nodiscard]] std::pair do_get_mem_info( + cuda_stream_view stream) const override { return {allocation_limit_ - allocated_bytes_, allocation_limit_}; } @@ -220,5 +223,4 @@ limiting_resource_adaptor make_limiting_adaptor(Upstream* upstream, return limiting_resource_adaptor{upstream, allocation_limit}; } -} // namespace mr -} // namespace rmm +} // namespace rmm::mr From 72dbb4d8dcf7d932faf65ebc9a451dd896e0a042 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 13 Sep 2021 12:12:15 -0700 Subject: [PATCH 2/2] review feedback --- .../mr/device/limiting_resource_adaptor.hpp | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/include/rmm/mr/device/limiting_resource_adaptor.hpp b/include/rmm/mr/device/limiting_resource_adaptor.hpp index 6f00b937f..b83fe3911 100644 --- a/include/rmm/mr/device/limiting_resource_adaptor.hpp +++ b/include/rmm/mr/device/limiting_resource_adaptor.hpp @@ -129,18 +129,19 @@ class limiting_resource_adaptor final : public device_memory_resource { */ void* do_allocate(std::size_t bytes, cuda_stream_view stream) override { - void* p = nullptr; - - std::size_t proposed_size = rmm::detail::align_up(bytes, allocation_alignment_); - allocated_bytes_ += proposed_size; - if (allocated_bytes_ <= allocation_limit_) { - p = upstream_->allocate(bytes, stream); - } else { - allocated_bytes_ -= proposed_size; - RMM_FAIL("Exceeded memory limit", rmm::bad_alloc); + auto const proposed_size = rmm::detail::align_up(bytes, allocation_alignment_); + auto const old = allocated_bytes_.fetch_add(proposed_size); + if (old + proposed_size <= allocation_limit_) { + try { + return upstream_->allocate(bytes, stream); + } catch (...) { + allocated_bytes_ -= proposed_size; + throw; + } } - return p; + allocated_bytes_ -= proposed_size; + RMM_FAIL("Exceeded memory limit", rmm::bad_alloc); } /**