Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
[Bugfix] Fine-tune gptq_marlin configs to be more similar to marlin (v…
Browse files Browse the repository at this point in the history
  • Loading branch information
alexm-redhat authored and robertgshaw2-redhat committed May 19, 2024
1 parent edd9e90 commit 32314e5
Showing 1 changed file with 35 additions and 13 deletions.
48 changes: 35 additions & 13 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
return res;
}

// Constructs destination register by taking bytes from 2 sources (based on mask)
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
Expand Down Expand Up @@ -933,9 +934,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
};

// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped partitioning
// minimizes the number of such reductions and our outputs are usually rather
// small, we perform this reduction serially in L2 cache.
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
Expand Down Expand Up @@ -1275,13 +1276,22 @@ typedef struct {
thread_config_t tb_cfg;
} exec_config_t;

thread_config_t thread_configs[] = {
thread_config_t small_batch_thread_configs[] = {
// Ordered by priority

// thread_k, thread_n, num_threads
{64, 256, 256}, // Default (max cache usage)
{64, 128, 128}, // Reduce N, reduce warps
{128, 64, 128}, // Reduce N more, but increase K
{128, 128, 256},
{64, 128, 128},
{128, 64, 128},
};

// Candidate thread-block configurations for the large-batch path, i.e. when
// determine_thread_config sees prob_m > 16. For a given max_m_blocks value the
// entries are tried in order and the first one accepted by is_valid_config is
// returned, so the order of this table is significant.
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority

// thread_k, thread_n, num_threads
{64, 256, 256}, // Default
{64, 128, 128}, // Reduce N, reduce warps
{128, 64, 128}, // Reduce N more, but increase K

};

Expand Down Expand Up @@ -1397,11 +1407,21 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
int max_shared_mem) {
int max_m_blocks = 4;
while (max_m_blocks > 0) {
for (auto th_config : thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
if (prob_m <= 16) {
for (auto th_config : small_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
} else {
for (auto th_config : large_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
}

Expand Down Expand Up @@ -1574,10 +1594,12 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
}
CALL_IF(4, 32, 2, 256)
CALL_IF(4, 16, 4, 256)
CALL_IF(4, 8, 8, 256)
CALL_IF(4, 8, 4, 128)
CALL_IF(4, 4, 8, 128)
CALL_IF(8, 32, 2, 256)
CALL_IF(8, 16, 4, 256)
CALL_IF(8, 8, 8, 256)
CALL_IF(8, 8, 4, 128)
CALL_IF(8, 4, 8, 128)
else {
Expand Down

0 comments on commit 32314e5

Please sign in to comment.