diff --git a/HeterogeneousCore/ROCmUtilities/interface/requireDevices.h b/HeterogeneousCore/ROCmUtilities/interface/requireDevices.h new file mode 100644 index 0000000000000..d09e262f861ec --- /dev/null +++ b/HeterogeneousCore/ROCmUtilities/interface/requireDevices.h @@ -0,0 +1,19 @@ +#ifndef HeterogeneousCore_ROCmUtilities_interface_requireDevices_h +#define HeterogeneousCore_ROCmUtilities_interface_requireDevices_h + +/** + * These functions are meant to be called only from unit tests. + */ +namespace cms { + namespace rocmtest { + + /// In presence of ROCm devices, return true; otherwise print message and return false + bool testDevices(); + + /// Print message and exit if there are no ROCm devices + void requireDevices(); + + } // namespace rocmtest +} // namespace cms + +#endif // HeterogeneousCore_ROCmUtilities_interface_requireDevices_h diff --git a/HeterogeneousCore/ROCmUtilities/src/requireDevices.cc b/HeterogeneousCore/ROCmUtilities/src/requireDevices.cc new file mode 100644 index 0000000000000..b62d48d383d07 --- /dev/null +++ b/HeterogeneousCore/ROCmUtilities/src/requireDevices.cc @@ -0,0 +1,30 @@ +#include +#include + +#include + +#include "HeterogeneousCore/ROCmUtilities/interface/requireDevices.h" + +namespace cms::rocmtest { + + bool testDevices() { + int devices = 0; + auto status = hipGetDeviceCount(&devices); + if (status != hipSuccess) { + std::cerr << "Failed to initialise the ROCm runtime, the test will be skipped.\n"; + return false; + } + if (devices == 0) { + std::cerr << "No ROCm devices available, the test will be skipped.\n"; + return false; + } + return true; + } + + void requireDevices() { + if (not testDevices()) { + exit(EXIT_SUCCESS); + } + } + +} // namespace cms::rocmtest diff --git a/HeterogeneousCore/ROCmUtilities/test/BuildFile.xml b/HeterogeneousCore/ROCmUtilities/test/BuildFile.xml index b07a46959b8c9..516b14b94cbd9 100644 --- a/HeterogeneousCore/ROCmUtilities/test/BuildFile.xml +++ b/HeterogeneousCore/ROCmUtilities/test/BuildFile.xml @@ -1,7 +1,13 @@ - + + + + + + + diff --git a/HeterogeneousCore/ROCmUtilities/test/testRequireROCmDevices.cpp b/HeterogeneousCore/ROCmUtilities/test/testRequireROCmDevices.cpp new file mode 100644 index 0000000000000..edc2eff9672ea --- /dev/null +++ b/HeterogeneousCore/ROCmUtilities/test/testRequireROCmDevices.cpp @@ -0,0 +1,21 @@ +// Catch2 headers +#define CATCH_CONFIG_MAIN +#include + +// ROCm headers +#include + +// CMSSW headers +#include "HeterogeneousCore/ROCmUtilities/interface/hipCheck.h" +#include "HeterogeneousCore/ROCmUtilities/interface/requireDevices.h" + +TEST_CASE("HeterogeneousCore/ROCmUtilities testRequireROCmDevices", "[testRequireROCmDevices]") { + SECTION("Test requireDevices()") { + cms::rocmtest::requireDevices(); + + int devices = 0; + hipCheck(hipGetDeviceCount(&devices)); + + REQUIRE(devices > 0); + } +}