Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for different vector sizes in multimem instructions #332

Merged
merged 1 commit into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions include/mscclpp/nvls_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,51 +25,71 @@ struct DeviceMulticastPointerDeviceHandle {
size_t bufferSize;

#if defined(MSCCLPP_DEVICE_CUDA)
template <int NElemPerThread = 4, typename TValue = float4, typename T = float>
template <typename TValue = float4, typename T = float>
MSCCLPP_DEVICE_INLINE static void multimemLoadReduce(TValue& val, T* ptr) {
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
if constexpr (std::is_same<T, float>::value) {
if constexpr (std::is_same<TValue, float4>::value && std::is_same<T, float>::value) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same<TValue, uint4>::value && std::is_same<T, __half2>::value) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same<TValue, uint2>::value && std::is_same<T, __half2>::value) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.f16x2 {%0,%1}, [%2];"
: "=r"(val.x), "=r"(val.y)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half2>::value) {
asm("multimem.ld_reduce.relaxed.sys.global.add.f16x2 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half>::value) {
asm("multimem.ld_reduce.relaxed.sys.global.add.f16 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}
};

template <int NElemPerThread = 4, typename TValue, typename T>
template <typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemStore(const TValue& val, T* ptr) {
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
if constexpr (std::is_same<T, float>::value) {
if constexpr (std::is_same<TValue, float4>::value && std::is_same<T, float>::value) {
asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same<TValue, uint4>::value && std::is_same<T, __half2>::value) {
asm volatile("multimem.st.relaxed.sys.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<TValue, uint2>::value && std::is_same<T, __half2>::value) {
asm volatile("multimem.st.relaxed.sys.global.v2.f16x2 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
: "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half2>::value) {
asm volatile("multimem.st.relaxed.sys.global.f16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half>::value) {
asm volatile("multimem.st.relaxed.sys.global.f16 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}
};

template <int NElemPerThread = 4, typename TValue, typename T>
template <typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemStoreReduce(const TValue& val, T* ptr) {
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
if constexpr (std::is_same<T, float>::value) {
if constexpr (std::is_same<TValue, float4>::value && std::is_same<T, float>::value) {
asm volatile("multimem.red.relaxed.sys.global.add.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<T, half2>::value) {
} else if constexpr (std::is_same<TValue, uint4>::value && std::is_same<T, __half2>::value) {
asm volatile("multimem.red.relaxed.sys.global.add.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x),
"r"(val.y), "r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<TValue, uint2>::value && std::is_same<T, __half2>::value) {
asm volatile("multimem.red.relaxed.sys.global.add.v2.f16x2 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
: "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half2>::value) {
asm volatile("multimem.red.relaxed.sys.global.add.f16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half>::value) {
asm volatile("multimem.red.relaxed.sys.global.add.f16 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}
Expand Down
3 changes: 2 additions & 1 deletion src/nvls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ class NvlsConnection::Impl {
size_t getMinMcGran() { throw notSupportedError; }

private:
Error notSupportedError = Error("NVLS is not supported on this CUDA version", ErrorCode::InvalidUsage);
Error notSupportedError =
Error("NVLS is not supported on this CUDA version (< 12.1) or kernel version (< 5.6.0)", ErrorCode::InvalidUsage);
};
#endif // !(USE_NVLS)

Expand Down
Loading