-
-
Notifications
You must be signed in to change notification settings. Fork: 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Kernel] Remove marlin moe templating on thread_m_blocks #8573
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1342,9 +1342,6 @@ __device__ inline void MarlinMoESingle( | |
|
||
template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id | ||
const int threads, // number of threads in a threadblock | ||
const int thread_m_blocks, // number of 16x16 blocks in the m | ||
// dimension (batchsize) of the | ||
// threadblock | ||
const int thread_n_blocks, // same for n dimension (output) | ||
const int thread_k_blocks, // same for k dimension (reduction) | ||
const int stages, // number of stages for the async global->shared | ||
|
@@ -1459,9 +1456,6 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, | |
|
||
template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id | ||
const int threads, // number of threads in a threadblock | ||
const int thread_m_blocks, // number of 16x16 blocks in the m | ||
// dimension (batchsize) of the | ||
// threadblock | ||
const int thread_n_blocks, // same for n dimension (output) | ||
const int thread_k_blocks, // same for k dimension (reduction) | ||
const int stages, // number of stages for the async global->shared | ||
|
@@ -1515,26 +1509,24 @@ const int STAGES = 4; // 4 pipeline stages fit into shared memory | |
static constexpr int min_thread_n = 64; | ||
static constexpr int min_thread_k = 64; | ||
|
||
#define __CALL_IF_MOE(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ | ||
THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, \ | ||
NUM_THREADS) \ | ||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ | ||
thread_n_blocks == THREAD_N_BLOCKS && \ | ||
thread_k_blocks == THREAD_K_BLOCKS && \ | ||
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ | ||
num_threads == NUM_THREADS) { \ | ||
cudaFuncSetAttribute( \ | ||
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ | ||
THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \ | ||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ | ||
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ | ||
THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \ | ||
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \ | ||
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ | ||
g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ | ||
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ | ||
replicate_input, apply_weights, m_block, max_par, \ | ||
exec_cfg.max_m_blocks); \ | ||
#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ | ||
GROUP_BLOCKS, NUM_THREADS) \ | ||
else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ | ||
thread_k_blocks == THREAD_K_BLOCKS && \ | ||
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ | ||
num_threads == NUM_THREADS) { \ | ||
cudaFuncSetAttribute( \ | ||
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ | ||
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \ | ||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ | ||
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ | ||
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \ | ||
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \ | ||
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ | ||
g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ | ||
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ | ||
replicate_input, apply_weights, m_block, max_par, \ | ||
exec_cfg.max_m_blocks); \ | ||
} | ||
|
||
typedef struct { | ||
|
@@ -1711,31 +1703,16 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, | |
return exec_config_t{0, {-1, -1, -1}}; | ||
} | ||
|
||
#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ | ||
\ | ||
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ | ||
\ | ||
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ | ||
\ | ||
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ | ||
\ | ||
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) | ||
#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ | ||
\ | ||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ | ||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) | ||
|
||
void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, | ||
const void* sorted_ids, const void* topk_weights, | ||
|
@@ -1872,7 +1849,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, | |
for (int m_block = 0; m_block < tot_m_blocks; | ||
m_block += 4 * exec_cfg.max_m_blocks) { | ||
// make it max possible value | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — Should we remove the comment as well?

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — Actually, I was a bit overzealous here — restoring this so we can bring the unsupported-shapes error message back to its former glory. |
||
int thread_m_blocks = exec_cfg.max_m_blocks; | ||
|
||
if (false) { | ||
} | ||
|
@@ -1890,7 +1866,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, | |
", has_act_order = " + str(has_act_order) + | ||
", num_groups = " + str(num_groups) + | ||
", group_size = " + str(group_size) + | ||
", thread_m_blocks = " + str(thread_m_blocks) + | ||
", thread_n_blocks = " + str(thread_n_blocks) + | ||
", thread_k_blocks = " + str(thread_k_blocks)); | ||
} | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — Shouldn't all of this be just one `__CALL_IF_MOE`?