From 32314e5002ad3c3b157b5ba14af1e45e8b6c52fa Mon Sep 17 00:00:00 2001
From: alexm-nm <59768536+alexm-nm@users.noreply.github.com>
Date: Wed, 8 May 2024 20:14:31 -0400
Subject: [PATCH] [Bugfix] Fine-tune gptq_marlin configs to be more similar to
 marlin (#4626)

---
 csrc/quantization/gptq_marlin/gptq_marlin.cu | 48 ++++++++++++++------
 1 file changed, 35 insertions(+), 13 deletions(-)

diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu
index fd0837f0cb39c..9c6bff000e916 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -115,7 +115,8 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
   return res;
 }
 
-// Constructs destination register by taking bytes from 2 sources (based on mask)
+// Constructs destination register by taking bytes from 2 sources (based on
+// mask)
 template <int start_byte, int mask>
 __device__ inline uint32_t prmt(uint32_t a) {
   uint32_t res;
@@ -933,9 +934,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
   };
 
   // Since multiple threadblocks may process parts of the same column slice, we
-  // finally have to globally reduce over the results. As the striped partitioning
-  // minimizes the number of such reductions and our outputs are usually rather
-  // small, we perform this reduction serially in L2 cache.
+  // finally have to globally reduce over the results. As the striped
+  // partitioning minimizes the number of such reductions and our outputs are
+  // usually rather small, we perform this reduction serially in L2 cache.
   auto global_reduce = [&](bool first = false, bool last = false) {
     // We are very careful here to reduce directly in the output buffer to
     // maximize L2 cache utilization in this step. To do this, we write out
@@ -1275,13 +1276,22 @@ typedef struct {
   thread_config_t tb_cfg;
 } exec_config_t;
 
-thread_config_t thread_configs[] = {
+thread_config_t small_batch_thread_configs[] = {
     // Ordered by priority
 
     // thread_k, thread_n, num_threads
-    {64, 256, 256}, // Default (max cache usage)
-    {64, 128, 128}, // Reduce N, reduce warps
-    {128, 64, 128}, // Reduce N more, but increase K
+    {128, 128, 256},
+    {64, 128, 128},
+    {128, 64, 128},
+};
+
+thread_config_t large_batch_thread_configs[] = {
+    // Ordered by priority
+
+    // thread_k, thread_n, num_threads
+    {64, 256, 256},
+    {64, 128, 128},
+    {128, 64, 128},
 };
 
@@ -1397,11 +1407,21 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                       int max_shared_mem) {
   int max_m_blocks = 4;
   while (max_m_blocks > 0) {
-    for (auto th_config : thread_configs) {
-      if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
-                          num_bits, group_size, has_act_order, is_k_full,
-                          max_shared_mem)) {
-        return exec_config_t{max_m_blocks, th_config};
+    if (prob_m <= 16) {
+      for (auto th_config : small_batch_thread_configs) {
+        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
+                            num_bits, group_size, has_act_order, is_k_full,
+                            max_shared_mem)) {
+          return exec_config_t{max_m_blocks, th_config};
+        }
+      }
+    } else {
+      for (auto th_config : large_batch_thread_configs) {
+        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
+                            num_bits, group_size, has_act_order, is_k_full,
+                            max_shared_mem)) {
+          return exec_config_t{max_m_blocks, th_config};
+        }
       }
     }
 
@@ -1574,10 +1594,12 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
     }
     CALL_IF(4, 32, 2, 256)
     CALL_IF(4, 16, 4, 256)
+    CALL_IF(4, 8, 8, 256)
     CALL_IF(4, 8, 4, 128)
     CALL_IF(4, 4, 8, 128)
     CALL_IF(8, 32, 2, 256)
     CALL_IF(8, 16, 4, 256)
+    CALL_IF(8, 8, 8, 256)
     CALL_IF(8, 8, 4, 128)
     CALL_IF(8, 4, 8, 128)
     else {
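
The sketch below is not part of the patch; it is a rough standalone illustration of the config-selection heuristic the diff introduces. The table names small_batch_thread_configs / large_batch_thread_configs mirror the diff, while pick_config and the stubbed is_valid_config are hypothetical, simplified stand-ins for determine_thread_config / is_valid_config (the real check also validates num_bits, group_size, act-order, k-full and shared-memory limits).

// Standalone sketch (assumption: simplified signatures, stubbed validity check).
#include <cstdio>

struct thread_config_t { int thread_k, thread_n, num_threads; };

// Same candidate tables as the patch, ordered by priority.
thread_config_t small_batch_thread_configs[] = {
    {128, 128, 256}, {64, 128, 128}, {128, 64, 128}};
thread_config_t large_batch_thread_configs[] = {
    {64, 256, 256}, {64, 128, 128}, {128, 64, 128}};

// Stub: the real is_valid_config also checks tile divisibility, group size,
// act-order and the available shared memory.
bool is_valid_config(const thread_config_t &, int, int, int) { return true; }

thread_config_t pick_config(int prob_m, int prob_n, int prob_k) {
  // Small batches (prob_m <= 16) try the small-batch table first; larger
  // batches start from the wide {64, 256, 256} tile.
  auto &table = (prob_m <= 16) ? small_batch_thread_configs
                               : large_batch_thread_configs;
  for (const auto &cfg : table)
    if (is_valid_config(cfg, prob_m, prob_n, prob_k)) return cfg;
  return {-1, -1, -1};  // no valid config found
}

int main() {
  thread_config_t c = pick_config(8, 4096, 4096);
  std::printf("thread_k=%d thread_n=%d threads=%d\n", c.thread_k, c.thread_n,
              c.num_threads);
}

One plausible reading of the split (an interpretation, not stated in the patch): with prob_m <= 16 the smaller thread_n tiles produce more thread blocks along N, which helps occupancy when there is little parallelism across M, whereas large batches keep the original wide tile that maximizes cache reuse.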