Skip to content

Commit

Permalink
Set deviceCount conditionally to nbatch when dealing with many GPUs
Browse files Browse the repository at this point in the history
Limit the number of devices used to at most the length of the dimension we're trying to split.
  • Loading branch information
af-ayala authored and eng-flavio-teixeira committed Jul 8, 2024
1 parent 90db3d3 commit 183fe8e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
8 changes: 6 additions & 2 deletions clients/hipfft_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -959,9 +959,13 @@ class hipfft_params : public fft_params
int deviceCount = 0;
(void)hipGetDeviceCount(&deviceCount);

std::vector<int> GPUs(static_cast<size_t>(deviceCount));
// ensure that users request less than or equal to the total number of devices
if(multiGPU > deviceCount)
throw std::runtime_error("not enough devices for requested multi-gpu computation!");

std::vector<int> GPUs(static_cast<size_t>(multiGPU));
std::iota(GPUs.begin(), GPUs.end(), 0);
ret = hipfftXtSetGPUs(plan, deviceCount, GPUs.data());
ret = hipfftXtSetGPUs(plan, multiGPU, GPUs.data());

xt_worksize.resize(GPUs.size());
workbuffersize_ptr = xt_worksize.data();
Expand Down
3 changes: 2 additions & 1 deletion clients/tests/multi_device_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
// THE SOFTWARE.

#include "../../shared/accuracy_test.h"
#include <algorithm>
#include <gtest/gtest.h>
#include <hip/hip_runtime_api.h>

Expand Down Expand Up @@ -68,7 +69,7 @@ std::vector<fft_params> param_generator_multi_gpu()
if(param.nbatch == 1 && param.placement == fft_placement_notinplace)
continue;

param_multi.multiGPU = deviceCount;
param_multi.multiGPU = std::min(static_cast<int>(param.nbatch), deviceCount);
all_params.emplace_back(std::move(param_multi));
}
};
Expand Down

0 comments on commit 183fe8e

Please sign in to comment.