From f787d14127cbf3c0f4f0723ebceefdb57c509da9 Mon Sep 17 00:00:00 2001
From: Divakar Verma
Date: Thu, 9 May 2024 21:53:28 +0000
Subject: [PATCH 1/7] enable fused topK_softmax kernel for hip

---
 csrc/cuda_compat.h                          |  2 +
 csrc/moe/topk_softmax_kernels.cu            | 27 ++++++----
 setup.py                                    |  4 +-
 .../layers/fused_moe/fused_moe.py           | 50 ++++++++-----------
 4 files changed, 42 insertions(+), 41 deletions(-)

diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h
index 1ebb2e74a82fc..a9febc979ebe1 100644
--- a/csrc/cuda_compat.h
+++ b/csrc/cuda_compat.h
@@ -18,8 +18,10 @@

 #ifndef USE_ROCM
   #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
 #else
   #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor(var, lane_mask, width)
 #endif

 #ifndef USE_ROCM
diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu
index 8c65f40fe836a..6ba4fcdb3a3f2 100644
--- a/csrc/moe/topk_softmax_kernels.cu
+++ b/csrc/moe/topk_softmax_kernels.cu
@@ -19,15 +19,22 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#include "../cuda_compat.h"

-#include <cub/cub.cuh>
-#include <cub/util_type.cuh>
+#ifndef USE_ROCM
+    #include <cub/util_type.cuh>
+    #include <cub/cub.cuh>
+#else
+    #include <hipcub/hipcub.hpp>
+    #include <hipcub/util_type.hpp>
+#endif
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))

 namespace vllm {
 namespace moe {

-static constexpr int WARP_SIZE = 32;
-
 /// Aligned array type
 template <
     typename T,
@@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
 #pragma unroll
     for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
     {
-        thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
+        thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
     }

     // From this point, thread max in all the threads have the max within the row. 
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); + row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW); } // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables @@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); - int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); + float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); + int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); // We want lower indices to "win" in every thread so we break ties this way if (other_max > max_val || (other_max == max_val && other_expert < expert)) @@ -383,7 +390,7 @@ struct TopkConstants { static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); - static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; static constexpr int THREADS_PER_ROW = EXPERTS / VPT; static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; @@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; - static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; diff --git a/setup.py b/setup.py index a66af2c5d556f..5817ca86f88da 100644 --- a/setup.py +++ b/setup.py @@ -382,11 +382,9 @@ def _read_requirements(filename: str) -> List[str]: ext_modules = [] -if _is_cuda(): - ext_modules.append(CMakeExtension(name="vllm._moe_C")) - if not _is_neuron(): ext_modules.append(CMakeExtension(name="vllm._C")) + ext_modules.append(CMakeExtension(name="vllm._moe_C")) if _install_punica(): ext_modules.append(CMakeExtension(name="vllm._punica_C")) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bb7938b3715be..6eab50b10a2a7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -11,6 +11,8 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.utils import is_hip +import vllm._moe_C as moe_kernels + logger = init_logger(__name__) @@ -319,34 +321,26 @@ def fused_topk( M, _ = hidden_states.shape - if is_hip(): - # The MoE kernels are not yet supported on ROCm. 
- routing_weights = torch.softmax(gating_output, - dim=-1, - dtype=torch.float32) - topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) - else: - import vllm._moe_C as moe_kernels - - topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - moe_kernels.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), # TODO(woosuk): Optimize this. - ) - del token_expert_indicies # Not used. Will be used in the future. + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + moe_kernels.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. + if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids From 493a19ed0dbc91c425c82928bd508db77a40270a Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Mon, 20 May 2024 17:18:24 +0000 Subject: [PATCH 2/7] NIT: make ruff & yapf happy --- .../model_executor/layers/fused_moe/fused_moe.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6eab50b10a2a7..20a3c9f6f893f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -8,11 +8,9 @@ import triton import triton.language as tl +import vllm._moe_C as moe_kernels from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.utils import is_hip -import vllm._moe_C as moe_kernels - logger = init_logger(__name__) @@ -322,13 +320,13 @@ def fused_topk( M, _ = hidden_states.shape topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) + topk, + dtype=torch.float32, + device=hidden_states.device) topk_ids = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + topk, + dtype=torch.int32, + device=hidden_states.device) token_expert_indicies = torch.empty(M, topk, dtype=torch.int32, From 64b305891cadd5bc855952556222b7eb2a5ebf5c Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Mon, 20 May 2024 22:37:42 +0000 Subject: [PATCH 3/7] enable moe extension for hip in CMakeLists --- CMakeLists.txt | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 35846fd1cfa99..2abe632aea8b9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -308,6 +308,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") add_dependencies(default _C) + message(STATUS "Enabling moe extension.") + add_dependencies(default _moe_C) + # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and # there are supported target arches. 
@@ -317,8 +320,3 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") add_dependencies(default _punica_C) endif() endif() - -if(VLLM_GPU_LANG STREQUAL "CUDA") - message(STATUS "Enabling moe extension.") - add_dependencies(default _moe_C) -endif() From 9740125532643002db155feffccb8cba10ffcd59 Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Mon, 20 May 2024 22:41:15 +0000 Subject: [PATCH 4/7] Ray workaround for correct import (vllm._moe_C) in fused_moe.py --- Dockerfile.rocm | 1 + 1 file changed, 1 insertion(+) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 9bfe8446a519d..e30a2aaf30209 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -108,6 +108,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ && python3 setup.py install \ && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \ && cd .. From d716bace609dd3b4efc8bb8e1647e8e91e4985dc Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Mon, 20 May 2024 23:26:30 +0000 Subject: [PATCH 5/7] [fix]: compile _moe_C for cuda or hip only --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 5817ca86f88da..ac71fdc600c84 100644 --- a/setup.py +++ b/setup.py @@ -382,9 +382,11 @@ def _read_requirements(filename: str) -> List[str]: ext_modules = [] +if _is_cuda() or _is_hip(): + ext_modules.append(CMakeExtension(name="vllm._moe_C")) + if not _is_neuron(): ext_modules.append(CMakeExtension(name="vllm._C")) - ext_modules.append(CMakeExtension(name="vllm._moe_C")) if _install_punica(): ext_modules.append(CMakeExtension(name="vllm._punica_C")) From 87941fde6207b687e008395b0ad9c7bd3b422e99 Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Wed, 22 May 2024 11:59:22 -0500 Subject: [PATCH 6/7] Update cuda_compat.h make clang-format happy --- csrc/cuda_compat.h | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index e582ee1450af9..99631788416ef 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -17,11 +17,14 @@ #endif #ifndef USE_ROCM - #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) - #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) +#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) +#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) #else - #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) - #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor(var, lane_mask, width) +#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) +#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor(var, lane_mask, width) #endif #ifndef USE_ROCM From a949be9e3f150a311a0f8cc369d4f1b40570f3ff Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Wed, 22 May 2024 12:07:45 -0500 Subject: [PATCH 7/7] Update cuda_compat.h clang-format happy v2 --- csrc/cuda_compat.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 99631788416ef..82e55613d915a 100644 --- a/csrc/cuda_compat.h +++ 
b/csrc/cuda_compat.h @@ -17,14 +17,14 @@ #endif #ifndef USE_ROCM -#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ - __shfl_xor_sync(uint32_t(-1), var, lane_mask) -#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ - __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) #else -#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) -#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ - __shfl_xor(var, lane_mask, width) + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor(var, lane_mask, width) #endif #ifndef USE_ROCM
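For context, below is a minimal unfused sketch of what the fused topk_softmax path computes, mirroring the Python softmax-plus-topk fallback that patch 1 removes from fused_topk(). The function name, tensor shapes, and the int32 cast are illustrative assumptions rather than the exact vLLM API; the sketch is only meant as a quick sanity check against the moe_kernels.topk_softmax output on a ROCm build.

    import torch


    def reference_topk_softmax(gating_output: torch.Tensor, topk: int,
                               renormalize: bool):
        # Softmax over the expert (last) dimension of the router logits.
        routing_weights = torch.softmax(gating_output, dim=-1,
                                        dtype=torch.float32)
        # Keep the k largest weights per token and the expert ids they map to.
        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
        if renormalize:
            # Renormalize so the selected weights sum to 1 for each token.
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        # The fused path allocates int32 expert ids; match that here.
        return topk_weights, topk_ids.to(torch.int32)


    if __name__ == "__main__":
        # Illustrative shapes only: 4 tokens routed over 8 experts, top-2.
        logits = torch.randn(4, 8)
        weights, ids = reference_topk_softmax(logits, topk=2, renormalize=True)
        print(weights)
        print(ids)

With this series applied, ROCm builds take the same fused route through the vllm._moe_C extension instead of the fallback above, and the kernel's sub-warp butterfly reductions go through the new VLLM_SHFL_XOR_SYNC_WIDTH wrapper, which maps to __shfl_xor_sync on CUDA and __shfl_xor on HIP.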