Skip to content

Commit

Permalink
Give NowTask some teeth
Browse files Browse the repository at this point in the history
Summary:
Introducing an immovable `NowTask` isn't quite enough, since most `folly::coro` plumbing takes tasks by forwarding reference.

This means that dangerous patterns like this would still compile:

```
auto t = someNowTask();
// ...
co_await std::move(t);
```

To to go from "immovable task" to "task that can only be awaited in the statement that created it", we need to make sure that `folly::coro` plumbing takes `NowTask` only by value.

Unfortunately, there are other awaitables in the ecosystem, like `folly::coro::Baton`, which must be taken by reference. So, we have to accommodate both scenarios:
 - For awaitables not deriving from `MustAwaitImmediately` -- by reference (classic behavior)
 - `MustAwaitImmediately` awaitables like `NowTask` -- by value

In the absence of something like https://wg21.link/p2785, C++ does not support perfect forwarding for prvalues.

As far as I can tell, this forces me to branch most of the relevant functions on `is_must_await_immediately_t`, as you see above.

While this is ugly, I don't currently see a better way to get a safer task.

 ---

**NOTE:** I didn't fix up `co_withAsyncStack` because as far as I can tell, this is an implementation detail only called from `folly::coro` internals, where none of the callsites have potential for user-facing lifetime issues. I did audit them. The most annoying use-case was `*BarrierTask::await_transform`, but it seems OK too, since every callsite does `co_await co_viaIfAsync()`. Comparatively `await_transform` for `BlockingWaitTask` and `InlineTask*` are barely used.

Reviewed By: yfeldblum, ispeters

Differential Revision: D67883335

fbshipit-source-id: d1c61a484f58e9ba10ab29ee3eaa2c1ed6ced9fd
  • Loading branch information
Alexey Spiridonov authored and facebook-github-bot committed Feb 7, 2025
1 parent 1c55773 commit 9252c06
Show file tree
Hide file tree
Showing 8 changed files with 386 additions and 52 deletions.
47 changes: 45 additions & 2 deletions folly/coro/BlockingWait.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,9 @@ struct blocking_wait_fn {
.get(frame));
}

template <typename SemiAwaitable>
template <
typename SemiAwaitable,
std::enable_if_t<!is_must_await_immediately_v<SemiAwaitable>, int> = 0>
FOLLY_NOINLINE auto operator()(
SemiAwaitable&& awaitable, folly::DrivableExecutor* executor) const
-> detail::decay_rvalue_reference_t<semi_await_result_t<SemiAwaitable>> {
Expand All @@ -388,10 +390,33 @@ struct blocking_wait_fn {
static_cast<SemiAwaitable&&>(awaitable)))
.getVia(executor, frame));
}
template <
typename SemiAwaitable,
std::enable_if_t<is_must_await_immediately_v<SemiAwaitable>, int> = 0>
FOLLY_NOINLINE auto operator()(
SemiAwaitable awaitable, folly::DrivableExecutor* executor) const
-> detail::decay_rvalue_reference_t<semi_await_result_t<SemiAwaitable>> {
folly::AsyncStackFrame frame;
frame.setReturnAddress();

folly::AsyncStackRoot stackRoot;
stackRoot.setNextRoot(folly::tryGetCurrentAsyncStackRoot());
stackRoot.setStackFrameContext();
stackRoot.setTopFrame(frame);

return static_cast<
std::add_rvalue_reference_t<semi_await_result_t<SemiAwaitable>>>(
detail::makeRefBlockingWaitTask(
folly::coro::co_viaIfAsync(
folly::getKeepAliveToken(executor),
std::move(awaitable).unsafeMoveMustAwaitImmediately()))
.getVia(executor, frame));
}

template <
typename SemiAwaitable,
std::enable_if_t<!is_awaitable_v<SemiAwaitable>, int> = 0>
std::enable_if_t<!is_awaitable_v<SemiAwaitable>, int> = 0,
std::enable_if_t<!is_must_await_immediately_v<SemiAwaitable>, int> = 0>
auto operator()(SemiAwaitable&& awaitable) const
-> detail::decay_rvalue_reference_t<semi_await_result_t<SemiAwaitable>> {
std::exception_ptr eptr;
Expand All @@ -405,6 +430,24 @@ struct blocking_wait_fn {
}
std::rethrow_exception(eptr);
}
template <
typename SemiAwaitable,
std::enable_if_t<!is_awaitable_v<SemiAwaitable>, int> = 0,
std::enable_if_t<is_must_await_immediately_v<SemiAwaitable>, int> = 0>
auto operator()(SemiAwaitable awaitable) const
-> detail::decay_rvalue_reference_t<semi_await_result_t<SemiAwaitable>> {
std::exception_ptr eptr;
{
detail::BlockingWaitExecutor executor;
try {
return operator()(
std::move(awaitable).unsafeMoveMustAwaitImmediately(), &executor);
} catch (...) {
eptr = current_exception();
}
}
std::rethrow_exception(eptr);
}
};
inline constexpr blocking_wait_fn blocking_wait{};
static constexpr blocking_wait_fn const& blockingWait =
Expand Down
24 changes: 24 additions & 0 deletions folly/coro/Coroutine.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,30 @@ class ExtendedCoroutineHandle {
ExtendedCoroutinePromise* extended_{nullptr};
};

// When the `folly::coro` API sees a (semi)awaitable that derives from this
// (tested via `is_must_await_immediately_t` below), it will take the
// awaitable by-value instead of by-forwarding-reference. This supports
// immovable tasks that can only be awaited in the same statement that
// created them, like `NowTask`, `ClosureTask`, `MemberTask`, etc.
//
// To speak this protocol, your awaitable must do two things:
// - Derive from `private MustAwaitImmediately`
// - Implement `YourAwaitable unsafeMoveMustAwaitImmediately() &&`, which
// moves `*this` into a new prvalue of your type.
//
// WARNING: If we see usage of `unsafeMoveMustAwaitImmediately()` outside of
// `folly::coro`, we will convert it to a passkey pattern, and break you.
// Think of it as an explicit, private-to-`folly` move ctor.
//
// Caveat: If you encounter a public `folly::coro` API that is not
// `MustAwaitImmediately`-aware, and simply take the awaitable by `&&`,
// please report it, and/or branch it. See `NothrowAwaitable` e.g.
struct MustAwaitImmediately : private NonCopyableNonMovable {};

template <typename Awaitable>
inline constexpr bool is_must_await_immediately_v = // private inheritance ok!
std::is_base_of_v<MustAwaitImmediately, Awaitable>;

} // namespace folly::coro

#endif
30 changes: 28 additions & 2 deletions folly/coro/Task.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ class TaskPromiseBase {

FinalAwaiter final_suspend() noexcept { return {}; }

template <typename Awaitable>
template <
typename Awaitable,
std::enable_if_t<!is_must_await_immediately_v<Awaitable>, int> = 0>
auto await_transform(Awaitable&& awaitable) {
bypassExceptionThrowing_ =
bypassExceptionThrowing_ == BypassExceptionThrowing::REQUESTED
Expand All @@ -153,12 +155,36 @@ class TaskPromiseBase {
folly::coro::co_withCancellation(
cancelToken_, static_cast<Awaitable&&>(awaitable))));
}
template <
typename Awaitable,
std::enable_if_t<is_must_await_immediately_v<Awaitable>, int> = 0>
auto await_transform(Awaitable awaitable) {
bypassExceptionThrowing_ =
bypassExceptionThrowing_ == BypassExceptionThrowing::REQUESTED
? BypassExceptionThrowing::ACTIVE
: BypassExceptionThrowing::INACTIVE;

template <typename Awaitable>
return folly::coro::co_withAsyncStack(folly::coro::co_viaIfAsync(
executor_.get_alias(),
folly::coro::co_withCancellation(
cancelToken_,
std::move(awaitable).unsafeMoveMustAwaitImmediately())));
}

template <
typename Awaitable,
std::enable_if_t<!is_must_await_immediately_v<Awaitable>, int> = 0>
auto await_transform(NothrowAwaitable<Awaitable>&& awaitable) {
bypassExceptionThrowing_ = BypassExceptionThrowing::REQUESTED;
return await_transform(awaitable.unwrap());
}
template <
typename Awaitable,
std::enable_if_t<is_must_await_immediately_v<Awaitable>, int> = 0>
auto await_transform(NothrowAwaitable<Awaitable> awaitable) {
bypassExceptionThrowing_ = BypassExceptionThrowing::REQUESTED;
return await_transform(awaitable.unwrap().unsafeMoveMustAwaitImmediately());
}

auto await_transform(co_current_executor_t) noexcept {
return ready_awaitable<folly::Executor*>{executor_.get()};
Expand Down
14 changes: 12 additions & 2 deletions folly/coro/TaskWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,18 @@ class TaskPromiseWrapperBase {
auto initial_suspend() noexcept { return promise_.initial_suspend(); }
auto final_suspend() noexcept { return promise_.final_suspend(); }

auto await_transform(auto&& what) {
return promise_.await_transform(std::forward<decltype(what)>(what));
template <
typename Awaitable,
std::enable_if_t<!is_must_await_immediately_v<Awaitable>, int> = 0>
auto await_transform(Awaitable&& what) {
return promise_.await_transform(std::forward<Awaitable>(what));
}
template <
typename Awaitable,
std::enable_if_t<is_must_await_immediately_v<Awaitable>, int> = 0>
auto await_transform(Awaitable what) {
return promise_.await_transform(
std::move(what).unsafeMoveMustAwaitImmediately());
}

auto yield_value(auto&& v)
Expand Down
106 changes: 98 additions & 8 deletions folly/coro/ViaIfAsync.h
Original file line number Diff line number Diff line change
Expand Up @@ -493,15 +493,29 @@ template <
typename Awaitable,
std::enable_if_t<
is_awaitable_v<Awaitable> && !HasViaIfAsyncMethod<Awaitable>::value,
int> = 0>
int> = 0,
std::enable_if_t<!is_must_await_immediately_v<Awaitable>, int> = 0>
auto co_viaIfAsync(folly::Executor::KeepAlive<> executor, Awaitable&& awaitable)
-> ViaIfAsyncAwaitable<Awaitable> {
return ViaIfAsyncAwaitable<Awaitable>{
std::move(executor), static_cast<Awaitable&&>(awaitable)};
}
template <
typename Awaitable,
std::enable_if_t<
is_awaitable_v<Awaitable> && !HasViaIfAsyncMethod<Awaitable>::value,
int> = 0,
std::enable_if_t<is_must_await_immediately_v<Awaitable>, int> = 0>
auto co_viaIfAsync(folly::Executor::KeepAlive<> executor, Awaitable awaitable)
-> ViaIfAsyncAwaitable<Awaitable> {
return ViaIfAsyncAwaitable<Awaitable>{
std::move(executor), std::move(awaitable)};
}

struct ViaIfAsyncFunction {
template <typename Awaitable>
template <
typename Awaitable,
std::enable_if_t<!is_must_await_immediately_v<Awaitable>, int> = 0>
auto operator()(folly::Executor::KeepAlive<> executor, Awaitable&& awaitable)
const noexcept(noexcept(co_viaIfAsync(
std::move(executor), static_cast<Awaitable&&>(awaitable))))
Expand All @@ -510,7 +524,21 @@ struct ViaIfAsyncFunction {
return co_viaIfAsync(
std::move(executor), static_cast<Awaitable&&>(awaitable));
}
};
template <
typename Awaitable,
std::enable_if_t<is_must_await_immediately_v<Awaitable>, int> = 0>
auto operator()(folly::Executor::KeepAlive<> executor, Awaitable awaitable)
const noexcept(noexcept(co_viaIfAsync(
std::move(executor),
std::move(awaitable).unsafeMoveMustAwaitImmediately())))
-> decltype(co_viaIfAsync(
std::move(executor),
std::move(awaitable).unsafeMoveMustAwaitImmediately())) {
return co_viaIfAsync(
std::move(executor),
std::move(awaitable).unsafeMoveMustAwaitImmediately());
}
}; // namespace adl

} // namespace adl
} // namespace detail
Expand Down Expand Up @@ -583,12 +611,24 @@ class TryAwaiter {
* co_withCancellation while keeping the corresponding awaitable on the outside
*/
template <template <typename T> typename Derived, typename T>
class CommutativeWrapperAwaitable {
class CommutativeWrapperAwaitable
: private std::conditional_t<
std::is_base_of_v<MustAwaitImmediately, T>,
MustAwaitImmediately,
Unit> {
public:
template <typename T2>
template <
typename T2,
std::enable_if_t<!is_must_await_immediately_v<T2>, int> = 0>
explicit CommutativeWrapperAwaitable(T2&& awaitable) noexcept(
std::is_nothrow_constructible_v<T, T2>)
: inner_(static_cast<T2&&>(awaitable)) {}
template <
typename T2,
std::enable_if_t<is_must_await_immediately_v<T2>, int> = 0>
explicit CommutativeWrapperAwaitable(T2 awaitable) noexcept(
std::is_nothrow_constructible_v<T, T2>)
: inner_(std::move(awaitable).unsafeMoveMustAwaitImmediately()) {}

template <typename Factory>
explicit CommutativeWrapperAwaitable(std::in_place_t, Factory&& factory)
Expand Down Expand Up @@ -623,6 +663,7 @@ class CommutativeWrapperAwaitable {

template <
typename T2 = T,
std::enable_if_t<!is_must_await_immediately_v<T2>, int> = 0,
typename Result = decltype(folly::coro::co_viaIfAsync(
std::declval<folly::Executor::KeepAlive<>>(), std::declval<T2>()))>
friend Derived<Result> co_viaIfAsync(
Expand All @@ -639,6 +680,31 @@ class CommutativeWrapperAwaitable {
std::move(executor), static_cast<T&&>(awaitable.inner_));
}};
}
template <
typename T2 = T,
std::enable_if_t<is_must_await_immediately_v<T2>, int> = 0,
typename Result = decltype(folly::coro::co_viaIfAsync(
std::declval<folly::Executor::KeepAlive<>>(),
std::declval<T2>().unsafeMoveMustAwaitImmediately()))>
friend Derived<Result>
co_viaIfAsync(folly::Executor::KeepAlive<> executor, Derived<T> awaitable) noexcept(
noexcept(folly::coro::co_viaIfAsync(
std::declval<folly::Executor::KeepAlive<>>(),
std::declval<T2>().unsafeMoveMustAwaitImmediately()))) {
return Derived<Result>{
std::in_place, [&]() {
return folly::coro::co_viaIfAsync(
std::move(executor),
std::move(awaitable.inner_).unsafeMoveMustAwaitImmediately());
}};
}

template <
typename T2 = T,
std::enable_if_t<is_must_await_immediately_v<T2>, int> = 0>
auto unsafeMoveMustAwaitImmediately() && {
return Derived<T>{std::move(inner_).unsafeMoveMustAwaitImmediately()};
}

protected:
T inner_;
Expand All @@ -657,20 +723,34 @@ class [[FOLLY_ATTR_CLANG_CORO_AWAIT_ELIDABLE]] TryAwaitable
std::is_same_v<remove_cvref_t<Self>, TryAwaitable>,
int> = 0,
typename T2 = like_t<Self, T>,
std::enable_if_t<is_awaitable_v<T2>, int> = 0>
std::enable_if_t<is_awaitable_v<T2>, int> = 0,
typename T3 = T,
// Future: If you have a compile error where this isn't satisfied, add
// a `true` branch calling `unsafeMoveMustAwaitImmediately()`.
std::enable_if_t<!is_must_await_immediately_v<T3>, int> = 0>
friend TryAwaiter<T2> operator co_await(Self && self) {
return TryAwaiter<T2>{static_cast<Self&&>(self).inner_};
}
};

} // namespace detail

template <typename Awaitable>
template <
typename Awaitable,
std::enable_if_t<!is_must_await_immediately_v<Awaitable>, int> = 0>
detail::TryAwaitable<remove_cvref_t<Awaitable>> co_awaitTry(
Awaitable&& awaitable) {
return detail::TryAwaitable<remove_cvref_t<Awaitable>>{
static_cast<Awaitable&&>(awaitable)};
}
template <
typename Awaitable,
std::enable_if_t<is_must_await_immediately_v<Awaitable>, int> = 0>
detail::TryAwaitable<remove_cvref_t<Awaitable>> co_awaitTry(
Awaitable awaitable) {
return detail::TryAwaitable<remove_cvref_t<Awaitable>>{
std::move(awaitable).unsafeMoveMustAwaitImmediately()};
}

template <typename T>
using semi_await_try_result_t =
Expand All @@ -692,12 +772,22 @@ class [[FOLLY_ATTR_CLANG_CORO_AWAIT_ELIDABLE]] NothrowAwaitable

} // namespace detail

template <typename Awaitable>
template <
typename Awaitable,
std::enable_if_t<!is_must_await_immediately_v<Awaitable>, int> = 0>
detail::NothrowAwaitable<remove_cvref_t<Awaitable>> co_nothrow(
[[FOLLY_ATTR_CLANG_CORO_AWAIT_ELIDABLE_ARGUMENT]] Awaitable&& awaitable) {
return detail::NothrowAwaitable<remove_cvref_t<Awaitable>>{
static_cast<Awaitable&&>(awaitable)};
}
template <
typename Awaitable,
std::enable_if_t<is_must_await_immediately_v<Awaitable>, int> = 0>
detail::NothrowAwaitable<remove_cvref_t<Awaitable>> co_nothrow(
[[FOLLY_ATTR_CLANG_CORO_AWAIT_ELIDABLE_ARGUMENT]] Awaitable awaitable) {
return detail::NothrowAwaitable<remove_cvref_t<Awaitable>>{
std::move(awaitable).unsafeMoveMustAwaitImmediately()};
}

} // namespace coro
} // namespace folly
Expand Down
Loading

0 comments on commit 9252c06

Please sign in to comment.