Skip to content

Commit

Permalink
Merge pull request #99 from eng-flavio-teixeira/fix_multi_gpu_tests_e…
Browse files Browse the repository at this point in the history
…rror_6.2

Fix multi-gpu test errors
  • Loading branch information
mamaydeo authored Jul 19, 2024
2 parents 90db3d3 + 183fe8e commit baee8f8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
8 changes: 6 additions & 2 deletions clients/hipfft_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -959,9 +959,13 @@ class hipfft_params : public fft_params
int deviceCount = 0;
(void)hipGetDeviceCount(&deviceCount);

std::vector<int> GPUs(static_cast<size_t>(deviceCount));
// ensure that users request less than or equal to the total number of devices
if(multiGPU > deviceCount)
throw std::runtime_error("not enough devices for requested multi-gpu computation!");

std::vector<int> GPUs(static_cast<size_t>(multiGPU));
std::iota(GPUs.begin(), GPUs.end(), 0);
ret = hipfftXtSetGPUs(plan, deviceCount, GPUs.data());
ret = hipfftXtSetGPUs(plan, multiGPU, GPUs.data());

xt_worksize.resize(GPUs.size());
workbuffersize_ptr = xt_worksize.data();
Expand Down
3 changes: 2 additions & 1 deletion clients/tests/multi_device_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
// THE SOFTWARE.

#include "../../shared/accuracy_test.h"
#include <algorithm>
#include <gtest/gtest.h>
#include <hip/hip_runtime_api.h>

Expand Down Expand Up @@ -68,7 +69,7 @@ std::vector<fft_params> param_generator_multi_gpu()
if(param.nbatch == 1 && param.placement == fft_placement_notinplace)
continue;

param_multi.multiGPU = deviceCount;
param_multi.multiGPU = std::min(static_cast<int>(param.nbatch), deviceCount);
all_params.emplace_back(std::move(param_multi));
}
};
Expand Down

0 comments on commit baee8f8

Please sign in to comment.