This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

[Cherrypick, Bugfix] Fix marlin 2:4 kernel crash on H100 #245

Merged
merged 1 commit into from May 16, 2024
16 changes: 5 additions & 11 deletions csrc/quantization/marlin/sparse/common/mem.h
@@ -45,19 +45,13 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
   );
 }
 
-// Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for
-// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need
-// for inputs A and outputs C.
-__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
+// Asynchronous global->shared copy
+__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
   const int BYTES = 16;
   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile(
-      "{\n"
-      "   .reg .b64 p;\n"
-      "   createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
-      "   cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
-      "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)
-  );
+  asm volatile("{\n"
+               "   cp.async.cg.shared.global [%0], [%1], %2;\n"
+               "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
 }
 
 // Async copy fence.
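For readers following the change above: the helper keeps the same `cp.async.cg` 16-byte copy but drops the `createpolicy` / `L2::cache_hint` eviction hint, so the quantized-weight tiles now go through the default L2 policy. Below is a minimal, self-contained sketch (an illustration, not code from this PR) of how a `cp_async4`-style helper is paired with the commit/wait fences that `mem.h` alludes to; the `cp_async_fence` / `cp_async_wait` bodies here are written out locally and assumed to mirror the usual Marlin helpers. Build with `nvcc -arch=sm_80` or newer, since `cp.async` requires sm_80+.

```cuda
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Same body as the new cp_async4 above: one 16-byte async global->shared copy.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("{\n"
               "   cp.async.cg.shared.global [%0], [%1], %2;\n"
               "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Commit all cp.async ops issued so far into one group.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` committed groups are still in flight.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" :: "n"(n));
}

// Each thread stages one int4 through shared memory and writes it back out.
__global__ void copy_kernel(const int4* __restrict__ in, int4* __restrict__ out) {
  __shared__ int4 buf[128];
  cp_async4(&buf[threadIdx.x], &in[threadIdx.x]);  // async global -> shared
  cp_async_fence();                                // commit the copy
  cp_async_wait<0>();                              // wait for it to land
  __syncthreads();
  out[threadIdx.x] = buf[threadIdx.x];
}

int main() {
  const int N = 128;
  int4 *in = nullptr, *out = nullptr;
  cudaMalloc(&in, N * sizeof(int4));
  cudaMalloc(&out, N * sizeof(int4));
  cudaMemset(in, 1, N * sizeof(int4));
  copy_kernel<<<1, N>>>(in, out);
  cudaDeviceSynchronize();
  printf("kernel status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```

The same issue/commit/wait sequence is what the `Marlin_24` hunks below rely on when they call `cp_async4(...)` followed by `cp_async_fence()`.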
10 changes: 5 additions & 5 deletions csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
@@ -392,7 +392,7 @@ __global__ void Marlin_24(
   for (int i = 0; i < b_sh_wr_iters; i++) {
 #pragma unroll
     for (int j = 0; j < b_thread_vecs; j++) {
-      cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
+      cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
                 B_ptr[i] + j);
     }
     B_ptr[i] += b_gl_rd_delta_o;
@@ -401,15 +401,15 @@
 #pragma unroll
   for (int i = 0; i < m_sh_iters; i++) {
     if (m_sh_wr_pred)
-      cp_async4_stream(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr],
+      cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr],
                 meta_ptr[i]);
     meta_ptr[i] += m_gl_rd_delta_o;
   }
   // Only fetch scales if this tile starts a new group
   if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
     int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
     if (s_sh_wr_pred)
-      cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+      cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
     s_gl_rd += s_gl_rd_delta;
   }
 }
@@ -763,12 +763,12 @@ __global__ void Marlin_24(
   if constexpr (group_blocks == -1) {
     if constexpr (num_bits == 8) {
       if (s_sh_wr_pred)
-        cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
+        cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
       cp_async_fence();
     } else {
       if (last) {
         if (s_sh_wr_pred)
-          cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
+          cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
         cp_async_fence();
       }
     }
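To make the pattern in these hunks concrete, here is a minimal double-buffered sketch (an illustrative assumption, not the PR's code) of prefetching the next tile with `cp_async4` while the current tile is consumed, fencing each stage with `cp_async_fence()`. The helpers are repeated so the file compiles on its own; the kernel itself pipelines more stages and predicates its copies (note the `m_sh_wr_pred` / `s_sh_wr_pred` guards above), but the issue/commit/wait shape is the same. Names such as `pipelined_sum`, `STAGES`, and `TILE` are hypothetical.

```cuda
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Helpers repeated from the sketch above so this file compiles on its own.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n"
               :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" :: "n"(n));
}

constexpr int STAGES = 2;   // double buffering
constexpr int TILE   = 128; // one int4 per thread per tile

// Sum the lanes of `num_tiles` int4 tiles, prefetching tile t+1 while tile t is consumed.
__global__ void pipelined_sum(const int4* __restrict__ B, int* __restrict__ out,
                              int num_tiles) {
  __shared__ int4 sh_b[STAGES][TILE];

  // Prefetch the first tile before entering the main loop.
  cp_async4(&sh_b[0][threadIdx.x], &B[threadIdx.x]);
  cp_async_fence();

  int acc = 0;
  for (int t = 0; t < num_tiles; ++t) {
    // Issue the copy for the next tile before consuming the current one.
    if (t + 1 < num_tiles)
      cp_async4(&sh_b[(t + 1) % STAGES][threadIdx.x],
                &B[(t + 1) * TILE + threadIdx.x]);
    cp_async_fence();

    // Leave at most one group (the copy just committed) in flight; everything
    // older, i.e. tile t, has then landed in shared memory.
    cp_async_wait<1>();
    __syncthreads();

    int4 v = sh_b[t % STAGES][threadIdx.x];
    acc += v.x + v.y + v.z + v.w;
    __syncthreads();  // finish reading this buffer before a later copy overwrites it
  }
  out[threadIdx.x] = acc;
}

int main() {
  const int num_tiles = 8;
  int4* dB = nullptr;
  int* dOut = nullptr;
  cudaMalloc(&dB, num_tiles * TILE * sizeof(int4));
  cudaMalloc(&dOut, TILE * sizeof(int));
  cudaMemset(dB, 0, num_tiles * TILE * sizeof(int4));
  pipelined_sum<<<1, TILE>>>(dB, dOut, num_tiles);
  cudaDeviceSynchronize();
  printf("status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(dB);
  cudaFree(dOut);
  return 0;
}
```

Double buffering (`STAGES = 2`) is the smallest pipeline that lets the copy for tile t+1 overlap work on tile t; deeper pipelines rotate more shared-memory slices in the same way.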