diff --git a/include/mscclpp/nvls_device.hpp b/include/mscclpp/nvls_device.hpp
index e59a66037..57f65464c 100644
--- a/include/mscclpp/nvls_device.hpp
+++ b/include/mscclpp/nvls_device.hpp
@@ -25,51 +25,71 @@ struct DeviceMulticastPointerDeviceHandle {
   size_t bufferSize;
 #if defined(MSCCLPP_DEVICE_CUDA)
-  template <int NElemPerThread = 4, typename TValue = float4, typename T = float>
+  template <typename TValue, typename T>
   MSCCLPP_DEVICE_INLINE static void multimemLoadReduce(TValue& val, T* ptr) {
-    static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
-    if constexpr (std::is_same<T, float>::value) {
+    if constexpr (std::is_same<TValue, float4>::value && std::is_same<T, float>::value) {
       asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
           : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
           : "l"(ptr)
           : "memory");
-    } else if constexpr (std::is_same<T, __half2>::value) {
+    } else if constexpr (std::is_same<TValue, uint4>::value && std::is_same<T, __half2>::value) {
       asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
           : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
           : "l"(ptr)
           : "memory");
+    } else if constexpr (std::is_same<TValue, uint2>::value && std::is_same<T, __half2>::value) {
+      asm("multimem.ld_reduce.relaxed.sys.global.add.v2.f16x2 {%0,%1}, [%2];"
+          : "=r"(val.x), "=r"(val.y)
+          : "l"(ptr)
+          : "memory");
+    } else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half2>::value) {
+      asm("multimem.ld_reduce.relaxed.sys.global.add.f16x2 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
+    } else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half>::value) {
+      asm("multimem.ld_reduce.relaxed.sys.global.add.f16 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
     } else {
       static_assert(dependentFalse<T>, "Not supported type");
     }
   };
 
-  template <int NElemPerThread = 4, typename TValue = float4, typename T = float>
+  template <typename TValue, typename T>
   MSCCLPP_DEVICE_INLINE static void multimemStore(const TValue& val, T* ptr) {
-    static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
-    if constexpr (std::is_same<T, float>::value) {
+    if constexpr (std::is_same<TValue, float4>::value && std::is_same<T, float>::value) {
       asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
                    "r"(val.z), "r"(val.w)
                    : "memory");
-    } else if constexpr (std::is_same<T, __half2>::value) {
+    } else if constexpr (std::is_same<TValue, uint4>::value && std::is_same<T, __half2>::value) {
       asm volatile("multimem.st.relaxed.sys.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
                    "r"(val.z), "r"(val.w)
                    : "memory");
+    } else if constexpr (std::is_same<TValue, uint2>::value && std::is_same<T, __half2>::value) {
+      asm volatile("multimem.st.relaxed.sys.global.v2.f16x2 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
+                   : "memory");
+    } else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half2>::value) {
+      asm volatile("multimem.st.relaxed.sys.global.f16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
+    } else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half>::value) {
+      asm volatile("multimem.st.relaxed.sys.global.f16 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
     } else {
       static_assert(dependentFalse<T>, "Not supported type");
     }
   };
 
-  template <int NElemPerThread = 4, typename TValue = float4, typename T = float>
+  template <typename TValue, typename T>
   MSCCLPP_DEVICE_INLINE static void multimemStoreReduce(const TValue& val, T* ptr) {
-    static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
-    if constexpr (std::is_same<T, float>::value) {
+    if constexpr (std::is_same<TValue, float4>::value && std::is_same<T, float>::value) {
       asm volatile("multimem.red.relaxed.sys.global.add.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
                    "r"(val.z), "r"(val.w)
                    : "memory");
-    } else if constexpr (std::is_same<T, __half2>::value) {
+    } else if constexpr (std::is_same<TValue, uint4>::value && std::is_same<T, __half2>::value) {
       asm volatile("multimem.red.relaxed.sys.global.add.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x),
                    "r"(val.y), "r"(val.z), "r"(val.w)
                    : "memory");
+    } else if constexpr (std::is_same<TValue, uint2>::value && std::is_same<T, __half2>::value) {
+      asm volatile("multimem.red.relaxed.sys.global.add.v2.f16x2 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
+                   : "memory");
+    } else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half2>::value) {
+      asm volatile("multimem.red.relaxed.sys.global.add.f16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
+    } else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half>::value) {
+      asm volatile("multimem.red.relaxed.sys.global.add.f16 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
     } else {
       static_assert(dependentFalse<T>, "Not supported type");
     }
diff --git a/src/nvls.cc b/src/nvls.cc
index 7e3f2a41d..5504e3b25 100644
--- a/src/nvls.cc
+++ b/src/nvls.cc
@@ -232,7 +232,8 @@ class NvlsConnection::Impl {
   size_t getMinMcGran() { throw notSupportedError; }
 
  private:
-  Error notSupportedError = Error("NVLS is not supported on this CUDA version", ErrorCode::InvalidUsage);
+  Error notSupportedError =
+      Error("NVLS is not supported on this CUDA version (< 12.1) or kernel version (< 5.6.0)", ErrorCode::InvalidUsage);
 };
 
 #endif  // !(USE_NVLS)