Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
[Bugfix] Fix marlin 2:4 kernel crash on H100
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin committed May 16, 2024
1 parent 6334dd3 commit 2a11024
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 16 deletions.
16 changes: 5 additions & 11 deletions csrc/quantization/marlin/sparse/common/mem.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,13 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
);
}

// Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for
// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need
// for inputs A and outputs C.
__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
// Asynchronous global->shared copy
__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .b64 p;\n"
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
"}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)
);
asm volatile("{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
Expand Down
10 changes: 5 additions & 5 deletions csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ __global__ void Marlin_24(
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < b_thread_vecs; j++) {
cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
B_ptr[i] + j);
}
B_ptr[i] += b_gl_rd_delta_o;
Expand All @@ -401,15 +401,15 @@ __global__ void Marlin_24(
#pragma unroll
for (int i = 0; i < m_sh_iters; i++) {
if (m_sh_wr_pred)
cp_async4_stream(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr],
cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr],
meta_ptr[i]);
meta_ptr[i] += m_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred)
cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
Expand Down Expand Up @@ -763,12 +763,12 @@ __global__ void Marlin_24(
if constexpr (group_blocks == -1) {
if constexpr (num_bits == 8) {
if (s_sh_wr_pred)
cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence();
} else {
if (last) {
if (s_sh_wr_pred)
cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence();
}
}
Expand Down

1 comment on commit 2a11024

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bigger_is_better

Benchmark suite Current: 2a11024 Previous: 6334dd3 Ratio
{"name": "request_throughput", "description": "VLLM Engine throughput - synthetic\nmodel - NousResearch/Llama-2-7b-chat-hf\nmax_model_len - 4096\nbenchmark_throughput {\n \"use-all-available-gpus_\": \"\",\n \"input-len\": 256,\n \"output-len\": 128,\n \"num-prompts\": 1000\n}", "gpu_description": "NVIDIA A10G x 1", "vllm_version": "0.3.0", "python_version": "3.10.12 (main, May 10 2024, 13:42:25) [GCC 9.4.0]", "torch_version": "2.3.0+cu121"} 3.8384854487498106 prompts/s 3.8418198063652103 prompts/s 1.00
{"name": "token_throughput", "description": "VLLM Engine throughput - synthetic\nmodel - NousResearch/Llama-2-7b-chat-hf\nmax_model_len - 4096\nbenchmark_throughput {\n \"use-all-available-gpus_\": \"\",\n \"input-len\": 256,\n \"output-len\": 128,\n \"num-prompts\": 1000\n}", "gpu_description": "NVIDIA A10G x 1", "vllm_version": "0.3.0", "python_version": "3.10.12 (main, May 10 2024, 13:42:25) [GCC 9.4.0]", "torch_version": "2.3.0+cu121"} 1473.9784123199272 tokens/s 1475.2588056442407 tokens/s 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.