diff --git a/include/rmm/exec_policy.hpp b/include/rmm/exec_policy.hpp index 98cd91cd4..fdd22249f 100644 --- a/include/rmm/exec_policy.hpp +++ b/include/rmm/exec_policy.hpp @@ -25,6 +25,7 @@ #include #include +#include namespace rmm { @@ -46,4 +47,33 @@ class exec_policy : public thrust_exec_policy_t { } }; +#if THRUST_VERSION >= 101600 + +using thrust_exec_policy_nosync_t = + thrust::detail::execute_with_allocator, + thrust::cuda_cub::execute_on_stream_nosync_base>; +/** + * @brief Helper class usable as a Thrust CUDA execution policy + * that uses RMM for temporary memory allocation on the specified stream + * and which allows the Thrust backend to skip stream synchronizations that + * are not required for correctness. + */ +class exec_policy_nosync : public thrust_exec_policy_nosync_t { + public: + explicit exec_policy_nosync( + cuda_stream_view stream = cuda_stream_default, + rmm::mr::device_memory_resource* mr = mr::get_current_device_resource()) + : thrust_exec_policy_nosync_t( + thrust::cuda::par_nosync(rmm::mr::thrust_allocator(stream, mr)).on(stream.value())) + { + } +}; + +#else + +using thrust_exec_policy_nosync_t = thrust_exec_policy_t; +using exec_policy_nosync = exec_policy; + +#endif + } // namespace rmm