diff --git a/source/adapters/cuda/device.cpp b/source/adapters/cuda/device.cpp index bbaaa27cdb..9c8a0c807c 100644 --- a/source/adapters/cuda/device.cpp +++ b/source/adapters/cuda/device.cpp @@ -57,12 +57,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, return ReturnValue(4318u); } case UR_DEVICE_INFO_MAX_COMPUTE_UNITS: { - int ComputeUnits = 0; - UR_CHECK_ERROR(cuDeviceGetAttribute( - &ComputeUnits, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, - hDevice->get())); - detail::ur::assertion(ComputeUnits >= 0); - return ReturnValue(static_cast(ComputeUnits)); + return ReturnValue(hDevice->getNumComputeUnits()); } case UR_DEVICE_INFO_MAX_WORK_ITEM_DIMENSIONS: { return ReturnValue(MaxWorkItemDimensions); diff --git a/source/adapters/cuda/device.hpp b/source/adapters/cuda/device.hpp index 0a40329026..3654f2bb36 100644 --- a/source/adapters/cuda/device.hpp +++ b/source/adapters/cuda/device.hpp @@ -32,6 +32,7 @@ struct ur_device_handle_t_ { int MaxCapacityLocalMem{0}; int MaxChosenLocalMem{0}; bool MaxLocalMemSizeChosen{false}; + uint32_t NumComputeUnits{0}; public: ur_device_handle_t_(native_type cuDevice, CUcontext cuContext, CUevent evBase, @@ -54,6 +55,10 @@ struct ur_device_handle_t_ { sizeof(MaxWorkGroupSize), &MaxWorkGroupSize, nullptr)); + UR_CHECK_ERROR(cuDeviceGetAttribute( + reinterpret_cast(&NumComputeUnits), + CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, cuDevice)); + // Set local mem max size if env var is present static const char *LocalMemSizePtrUR = std::getenv("UR_CUDA_MAX_LOCAL_MEM_SIZE"); @@ -107,6 +112,8 @@ struct ur_device_handle_t_ { int getMaxChosenLocalMem() const noexcept { return MaxChosenLocalMem; }; bool maxLocalMemSizeChosen() { return MaxLocalMemSizeChosen; }; + + uint32_t getNumComputeUnits() const noexcept { return NumComputeUnits; }; }; int getAttribute(ur_device_handle_t Device, CUdevice_attribute Attribute); diff --git a/source/adapters/cuda/kernel.cpp b/source/adapters/cuda/kernel.cpp index d43bd046dc..2061893744 100644 --- a/source/adapters/cuda/kernel.cpp +++ b/source/adapters/cuda/kernel.cpp @@ -167,10 +167,46 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle( UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( ur_kernel_handle_t hKernel, size_t localWorkSize, size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) { - (void)hKernel; - (void)localWorkSize; - (void)dynamicSharedMemorySize; - *pGroupCountRet = 1; + UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_KERNEL); + + // We need to set the active current device for this kernel explicitly here, + // because the occupancy querying API does not take device parameter. + ur_device_handle_t Device = hKernel->getProgram()->getDevice(); + ScopedContext Active(Device); + try { + // We need to calculate max num of work-groups using per-device semantics. + + int MaxNumActiveGroupsPerCU{0}; + UR_CHECK_ERROR(cuOccupancyMaxActiveBlocksPerMultiprocessor( + &MaxNumActiveGroupsPerCU, hKernel->get(), localWorkSize, + dynamicSharedMemorySize)); + detail::ur::assertion(MaxNumActiveGroupsPerCU >= 0); + // Handle the case where we can't have all SMs active with at least 1 group + // per SM. In that case, the device is still able to run 1 work-group, hence + // we will manually check if it is possible with the available HW resources. + if (MaxNumActiveGroupsPerCU == 0) { + size_t MaxWorkGroupSize{}; + urKernelGetGroupInfo( + hKernel, Device, UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE, + sizeof(MaxWorkGroupSize), &MaxWorkGroupSize, nullptr); + size_t MaxLocalSizeBytes{}; + urDeviceGetInfo(Device, UR_DEVICE_INFO_LOCAL_MEM_SIZE, + sizeof(MaxLocalSizeBytes), &MaxLocalSizeBytes, nullptr); + if (localWorkSize > MaxWorkGroupSize || + dynamicSharedMemorySize > MaxLocalSizeBytes || + hasExceededMaxRegistersPerBlock(Device, hKernel, localWorkSize)) + *pGroupCountRet = 0; + else + *pGroupCountRet = 1; + } else { + // Multiply by the number of SMs (CUs = compute units) on the device in + // order to retreive the total number of groups/blocks that can be + // launched. + *pGroupCountRet = Device->getNumComputeUnits() * MaxNumActiveGroupsPerCU; + } + } catch (ur_result_t Err) { + return Err; + } return UR_RESULT_SUCCESS; }