diff --git a/clients/hipfft_params.h b/clients/hipfft_params.h index 2002c41..1a6980f 100644 --- a/clients/hipfft_params.h +++ b/clients/hipfft_params.h @@ -959,9 +959,13 @@ class hipfft_params : public fft_params int deviceCount = 0; (void)hipGetDeviceCount(&deviceCount); - std::vector GPUs(static_cast(deviceCount)); + // ensure that users request less than or equal to the total number of devices + if(multiGPU > deviceCount) + throw std::runtime_error("not enough devices for requested multi-gpu computation!"); + + std::vector GPUs(static_cast(multiGPU)); std::iota(GPUs.begin(), GPUs.end(), 0); - ret = hipfftXtSetGPUs(plan, deviceCount, GPUs.data()); + ret = hipfftXtSetGPUs(plan, multiGPU, GPUs.data()); xt_worksize.resize(GPUs.size()); workbuffersize_ptr = xt_worksize.data(); diff --git a/clients/tests/multi_device_test.cpp b/clients/tests/multi_device_test.cpp index f7407c4..5e766d8 100644 --- a/clients/tests/multi_device_test.cpp +++ b/clients/tests/multi_device_test.cpp @@ -19,6 +19,7 @@ // THE SOFTWARE. #include "../../shared/accuracy_test.h" +#include #include #include @@ -68,7 +69,7 @@ std::vector param_generator_multi_gpu() if(param.nbatch == 1 && param.placement == fft_placement_notinplace) continue; - param_multi.multiGPU = deviceCount; + param_multi.multiGPU = std::min(static_cast(param.nbatch), deviceCount); all_params.emplace_back(std::move(param_multi)); } };