diff --git a/nvbench/cuda_stream.cuh b/nvbench/cuda_stream.cuh index 2c7536c..6329608 100644 --- a/nvbench/cuda_stream.cuh +++ b/nvbench/cuda_stream.cuh @@ -18,11 +18,14 @@ #pragma once -#include - #include +#include +#include +#include + #include +#include namespace nvbench { @@ -42,10 +45,18 @@ struct cuda_stream * Constructs a cuda_stream that owns a new stream, created with * `cudaStreamCreate`. */ - cuda_stream() - : m_stream{[]() { + cuda_stream(std::optional device) + : m_stream{[device]() { cudaStream_t s; - NVBENCH_CUDA_CALL(cudaStreamCreate(&s)); + if (device.has_value()) + { + nvbench::detail::device_scope scope_guard{device.value().get_id()}; + NVBENCH_CUDA_CALL(cudaStreamCreate(&s)); + } + else + { + NVBENCH_CUDA_CALL(cudaStreamCreate(&s)); + } return s; }(), stream_deleter{true}} diff --git a/nvbench/state.cuh b/nvbench/state.cuh index 53c7413..6a3afc9 100644 --- a/nvbench/state.cuh +++ b/nvbench/state.cuh @@ -261,7 +261,6 @@ private: std::optional device, std::size_t type_config_index); - nvbench::cuda_stream m_cuda_stream; std::reference_wrapper m_benchmark; nvbench::named_values m_axis_values; std::optional m_device; @@ -277,6 +276,8 @@ private: nvbench::float64_t m_skip_time; nvbench::float64_t m_timeout; + nvbench::cuda_stream m_cuda_stream; + // Deadlock protection. See blocking_kernel's class doc for details. nvbench::float64_t m_blocking_kernel_timeout{30.0}; diff --git a/nvbench/state.cxx b/nvbench/state.cxx index 3cf105c..f6f8993 100644 --- a/nvbench/state.cxx +++ b/nvbench/state.cxx @@ -41,6 +41,7 @@ state::state(const benchmark_base &bench) , m_max_noise{bench.get_max_noise()} , m_skip_time{bench.get_skip_time()} , m_timeout{bench.get_timeout()} + , m_cuda_stream{std::nullopt} {} state::state(const benchmark_base &bench, @@ -58,6 +59,7 @@ state::state(const benchmark_base &bench, , m_max_noise{bench.get_max_noise()} , m_skip_time{bench.get_skip_time()} , m_timeout{bench.get_timeout()} + , m_cuda_stream{m_device} {} nvbench::int64_t state::get_int64(const std::string &axis_name) const