Skip to content

Refactor AsyncWrapper to make it safer #1206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 80 additions & 59 deletions include/cpr/async_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
#include <future>
#include <memory>

#include "cpr/response.h"

namespace cpr {
enum class [[nodiscard]] CancellationResult { failure, success, invalid_operation };
enum class [[nodiscard]] CancellationResult : uint8_t { failure, success, invalid_operation };

/**
* A class template intended to wrap results of async operations (instances of std::future<T>)
Expand All @@ -17,15 +15,24 @@ enum class [[nodiscard]] CancellationResult { failure, success, invalid_operatio
* The RAII semantics are the same as std::future<T> - moveable, not copyable.
*/
template <typename T, bool isCancellable = false>
class AsyncWrapper {
class AsyncWrapper;

template <typename T>
class AsyncWrapper<T, false> {
private:
friend class AsyncWrapper<T, true>;
std::future<T> future;
std::shared_ptr<std::atomic_bool> is_cancelled;

void throw_if_invalid(const char* error) const {
if (!future.valid()) {
throw std::logic_error{error};
}
}

public:
// Constructors
AsyncWrapper() = default;
explicit AsyncWrapper(std::future<T>&& f) : future{std::move(f)} {}
AsyncWrapper(std::future<T>&& f, std::shared_ptr<std::atomic_bool>&& cancelledState) : future{std::move(f)}, is_cancelled{std::move(cancelledState)} {}

// Copy Semantics
AsyncWrapper(const AsyncWrapper&) = delete;
Expand All @@ -36,94 +43,108 @@ class AsyncWrapper {
AsyncWrapper& operator=(AsyncWrapper&&) noexcept = default;

// Destructor
~AsyncWrapper() {
if constexpr (isCancellable) {
if (is_cancelled) {
is_cancelled->store(true);
}
}
}
~AsyncWrapper() = default;

// These methods replicate the behaviour of std::future<T>
[[nodiscard]] T get() {
if constexpr (isCancellable) {
if (IsCancelled()) {
throw std::logic_error{"Calling AsyncWrapper::get on a cancelled request!"};
}
}
if (!future.valid()) {
throw std::logic_error{"Calling AsyncWrapper::get when the associated future instance is invalid!"};
}
throw_if_invalid("Calling AsyncWrapper::get when the associated future instance is invalid!");
return future.get();
}

[[nodiscard]] bool valid() const noexcept {
if constexpr (isCancellable) {
return !is_cancelled->load() && future.valid();
} else {
return future.valid();
}
return future.valid();
}

void wait() const {
if constexpr (isCancellable) {
if (is_cancelled->load()) {
throw std::logic_error{"Calling AsyncWrapper::wait when the associated future is invalid or cancelled!"};
}
}
if (!future.valid()) {
throw std::logic_error{"Calling AsyncWrapper::wait_until when the associated future is invalid!"};
}
throw_if_invalid("Calling AsyncWrapper::wait when the associated future is invalid!");
future.wait();
}

template <class Rep, class Period>
std::future_status wait_for(const std::chrono::duration<Rep, Period>& timeout_duration) const {
if constexpr (isCancellable) {
if (IsCancelled()) {
throw std::logic_error{"Calling AsyncWrapper::wait_for when the associated future is cancelled!"};
}
}
if (!future.valid()) {
throw std::logic_error{"Calling AsyncWrapper::wait_until when the associated future is invalid!"};
}
throw_if_invalid("Calling AsyncWrapper::wait_for when the associated future is invalid!");
return future.wait_for(timeout_duration);
}

template <class Clock, class Duration>
std::future_status wait_until(const std::chrono::time_point<Clock, Duration>& timeout_time) const {
if constexpr (isCancellable) {
if (IsCancelled()) {
throw std::logic_error{"Calling AsyncWrapper::wait_until when the associated future is cancelled!"};
}
}
if (!future.valid()) {
throw std::logic_error{"Calling AsyncWrapper::wait_until when the associated future is invalid!"};
}
throw_if_invalid("Calling AsyncWrapper::wait_until when the associated future is invalid!");
return future.wait_until(timeout_time);
}

std::shared_future<T> share() noexcept {
return future.share();
}
};

template <typename T>
class AsyncWrapper<T, true> : public AsyncWrapper<T, false> {
private:
using base = AsyncWrapper<T, false>;
std::shared_ptr<std::atomic_bool> is_cancelled;

void throw_if_cancelled(const char* error) const {
if (is_cancelled->load()) {
throw std::logic_error{error};
}
}

public:
// Constructors
AsyncWrapper(std::future<T>&& f, std::shared_ptr<std::atomic_bool>&& cancelledState) : base{std::move(f)}, is_cancelled{std::move(cancelledState)} {}

// Copy Semantics
AsyncWrapper(const AsyncWrapper&) = delete;
AsyncWrapper& operator=(const AsyncWrapper&) = delete;

// Move Semantics
AsyncWrapper(AsyncWrapper&&) noexcept = default;
AsyncWrapper& operator=(AsyncWrapper&&) noexcept = default;

// Destructor
~AsyncWrapper() {
if (is_cancelled) {
is_cancelled->store(true);
}
}

[[nodiscard]] T get() {
throw_if_cancelled("Calling AsyncWrapper::get on a cancelled request!");
return base::get();
}

[[nodiscard]] bool valid() const noexcept {
return !is_cancelled->load() && base::future.valid();
}

void wait() const {
throw_if_cancelled("Calling AsyncWrapper::wait when the associated future is invalid or cancelled!");
base::wait();
}

template <class Rep, class Period>
std::future_status wait_for(const std::chrono::duration<Rep, Period>& timeout_duration) const {
throw_if_cancelled("Calling AsyncWrapper::wait_for when the associated future is cancelled!");
return base::wait_for(timeout_duration);
}

template <class Clock, class Duration>
std::future_status wait_until(const std::chrono::time_point<Clock, Duration>& timeout_time) const {
throw_if_cancelled("Calling AsyncWrapper::wait_until when the associated future is cancelled!");
return base::wait_until(timeout_time);
}

// Cancellation-related methods
CancellationResult Cancel() {
if constexpr (!isCancellable) {
return CancellationResult::invalid_operation;
}
if (!future.valid() || is_cancelled->load()) {
if (!base::future.valid() || is_cancelled->load()) {
return CancellationResult::invalid_operation;
}
is_cancelled->store(true);
return CancellationResult::success;
}

[[nodiscard]] bool IsCancelled() const {
if constexpr (isCancellable) {
return is_cancelled->load();
} else {
return false;
}
return is_cancelled->load();
}
};

Expand Down
3 changes: 0 additions & 3 deletions test/multiasync_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@ TEST(AsyncWrapperNonCancellableTests, TestExceptionsNoSharedState) {

// We create an AsyncWrapper for a future without a shared state (default-initialized)
AsyncWrapper test_wrapper{std::future<std::string>{}};


ASSERT_FALSE(test_wrapper.valid());
ASSERT_FALSE(test_wrapper.IsCancelled());

// Trying to get or wait for a future that doesn't have a shared state should result to an exception
// It should be noted that there is a divergence from std::future behavior here: calling wait* on the original std::future is undefined behaviour, according to cppreference.com . We find it preferrable to throw an exception.
Expand Down
Loading