Skip to content

Commit

Permalink
misc bug fixes for cudax ustdex
Browse files Browse the repository at this point in the history
  • Loading branch information
ericniebler committed Feb 6, 2025
1 parent 044cabf commit 0c11c32
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 21 deletions.
3 changes: 3 additions & 0 deletions cudax/include/cuda/experimental/__async/sender/env.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ struct _CCCL_TYPE_VISIBILITY_DEFAULT prop
prop& operator=(const prop&) = delete;
};

template <class _Query, class _Value>
prop(_Query, _Value) -> prop<_Query, _Value>;

template <class... _Envs>
struct _CCCL_TYPE_VISIBILITY_DEFAULT env
{
Expand Down
2 changes: 2 additions & 0 deletions cudax/include/cuda/experimental/__async/sender/meta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ struct _IN_ALGORITHM;

struct _WHAT;

struct _WHY;

struct _WITH_FUNCTION;

struct _WITH_SENDER;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ struct __rcvr_with_env_t<_Rcvr*, _Env>
_Rcvr* __rcvr_;
_Env __env_;
};

template <class _Rcvr, class _Env>
__rcvr_with_env_t(_Rcvr, _Env) -> __rcvr_with_env_t<_Rcvr, _Env>;

} // namespace cuda::experimental::__async

#include <cuda/experimental/__async/sender/epilogue.cuh>
Expand Down
47 changes: 33 additions & 14 deletions cudax/include/cuda/experimental/__async/sender/sync_wait.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
// run_loop isn't supported on-device yet, so neither can sync_wait be.
#if !defined(__CUDA_ARCH__)

# include <cuda/std/__type_traits/type_identity.h>
# include <cuda/std/optional>
# include <cuda/std/tuple>

# include <cuda/experimental/__async/sender/exception.cuh>
# include <cuda/experimental/__async/sender/meta.cuh>
# include <cuda/experimental/__async/sender/run_loop.cuh>
# include <cuda/experimental/__async/sender/utility.cuh>
# include <cuda/experimental/__async/sender/write_env.cuh>

# include <system_error>

Expand Down Expand Up @@ -115,26 +117,38 @@ private:
}
};

using __values_t = value_types_of_t<_Sndr, __rcvr_t, _CUDA_VSTD::tuple, _CUDA_VSTD::__type_self_t>;
using __completions_t = completion_signatures_of_t<_Sndr, __rcvr_t>;

struct __on_success
{
using type = __value_types<__completions_t, _CUDA_VSTD::tuple, _CUDA_VSTD::__type_self_t>;
};

using __on_error = _CUDA_VSTD::type_identity<_CUDA_VSTD::tuple<__completions_t>>;

using __values_t =
typename _CUDA_VSTD::_If<__is_completion_signatures<__completions_t>, __on_success, __on_error>::type;

_CUDA_VSTD::optional<__values_t>* __values_;
::std::exception_ptr __eptr_;
run_loop __loop_;
};

struct __invalid_sync_wait
template <class _Type>
struct __always_false : _CUDA_VSTD::false_type
{};

template <class _Diagnostic>
struct __bad_sync_wait
{
const __invalid_sync_wait& value() const
{
return *this;
}
static_assert(__always_false<_Diagnostic>(),
"sync_wait cannot compute the completions of the sender passed to it.");
static __bad_sync_wait __result();

const __invalid_sync_wait& operator*() const
{
return *this;
}
const __bad_sync_wait& value() const;
const __bad_sync_wait& operator*() const;

int __i_;
int i{}; // so that structured bindings kinda work
};

public:
Expand Down Expand Up @@ -168,12 +182,11 @@ public:
{
using __rcvr_t = typename __state_t<_Sndr>::__rcvr_t;
using __values_t = typename __state_t<_Sndr>::__values_t;
using __completions = completion_signatures_of_t<_Sndr, __rcvr_t>;
static_assert(__is_completion_signatures<__completions>);
using __completions = typename __state_t<_Sndr>::__completions_t;

if constexpr (!__is_completion_signatures<__completions>)
{
return __invalid_sync_wait{0};
return __bad_sync_wait<__completions>::__result();
}
else
{
Expand All @@ -196,6 +209,12 @@ public:
return __result; // uses NRVO to "return" the result
}
}

template <class _Sndr, class _Env>
auto operator()(_Sndr&& __sndr, _Env&& __env) const
{
return (*this)(__async::write_env(static_cast<_Sndr&&>(__sndr), static_cast<_Env&&>(__env)));
}
};

_CCCL_GLOBAL_CONSTANT sync_wait_t sync_wait{};
Expand Down
10 changes: 4 additions & 6 deletions cudax/include/cuda/experimental/__async/sender/utility.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
# pragma system_header
#endif // no system header

#include <cuda/std/__tuple_dir/ignore.h>
#include <cuda/std/__type_traits/decay.h>
#include <cuda/std/__type_traits/is_same.h>
#include <cuda/std/initializer_list>

Expand All @@ -34,11 +36,7 @@ namespace cuda::experimental::__async
{
_CCCL_GLOBAL_CONSTANT size_t __npos = static_cast<size_t>(-1);

struct __ignore
{
template <class... _As>
_CUDAX_API constexpr __ignore(_As&&...) noexcept {};
};
using __ignore _CCCL_NODEBUG_ALIAS = _CUDA_VSTD::__ignore_t; // NOLINT: misc-unused-using-decls

using _CUDA_VSTD::__undefined; // NOLINT: misc-unused-using-decls

Expand Down Expand Up @@ -116,7 +114,7 @@ _CUDAX_API constexpr void __swap(_Ty& __left, _Ty& __right) noexcept
}

template <class _Ty>
_CUDAX_API constexpr _Ty __decay_copy(_Ty&& __ty) noexcept(__nothrow_decay_copyable<_Ty>)
_CUDAX_API constexpr _CUDA_VSTD::decay_t<_Ty> __decay_copy(_Ty&& __ty) noexcept(__nothrow_decay_copyable<_Ty>)
{
return static_cast<_Ty&&>(__ty);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ private:
connect_result_t<_Sndr, __rcvr_with_env_t<_Rcvr, _Env>*> __opstate_;

_CUDAX_API explicit __opstate_t(_Sndr&& __sndr, _Env __env, _Rcvr __rcvr)
: __env_rcvr_(static_cast<_Env&&>(__env), static_cast<_Rcvr&&>(__rcvr))
: __env_rcvr_{static_cast<_Rcvr&&>(__rcvr), static_cast<_Env&&>(__env)}
, __opstate_(__async::connect(static_cast<_Sndr&&>(__sndr), &__env_rcvr_))
{}

Expand Down

0 comments on commit 0c11c32

Please sign in to comment.