diff --git a/vllm/model_executor/layers/fused_moe/ampere_fp8_fused_moe.py b/vllm/model_executor/layers/fused_moe/ampere_fp8_fused_moe.py index 89fa746811110..0fd39e12256ab 100644 --- a/vllm/model_executor/layers/fused_moe/ampere_fp8_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/ampere_fp8_fused_moe.py @@ -1,5 +1,4 @@ """Fused MoE kernel.""" - import functools import json import os @@ -120,7 +119,7 @@ def fused_moe_kernel( ) if use_fp8: - #a_scale = tl.load(a_scale_ptr) + # a_scale = tl.load(a_scale_ptr) b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- @@ -138,9 +137,33 @@ def fused_moe_kernel( mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0, ).to(tl.float16) - #a = tl.extra.cuda.convert_uint8_as_fp8e4m3_to_float16(a) + # a = tl.extra.cuda.convert_uint8_as_fp8e4m3_to_float16(a) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.extra.cuda.convert_uint8_as_fp8e4m3_to_float16(b) * b_scale.to(tl.float16) + + b = tl.inline_asm_elementwise( + asm = "{ \n" + ".reg .b32 a<2>, b<2>; \n" # if input = 0xf1f2f3f4 + ".reg .b32 r0, r1; \n" + "prmt.b32 a0, 0, $1, 0x5040; \n" # a0 = 0xf300f400 + "prmt.b32 a1, 0, $1, 0x7060; \n" # a1 = 0xf100f200 + "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" # b0 = a0 & 0x7fff7fff + "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" # (strip sign) + "shr.b32 b0, b0, 1; \n" # b0 >>= 1 + "shr.b32 b1, b1, 1; \n" # shift into fp16 position + "add.u32 b0, b0, 0x20002000; \n" # b0.exp += 2**4-2**3 + # exponent compensate = 8 + "add.u32 b1, b1, 0x20002000; \n" # b1 += 8<<10 | 8<<10<<16 + "lop3.b32 r0, b0, 0x80008000, a0, 0xf8; \n" # out0 = b0|(0x80008000&a0) + "lop3.b32 r1, b1, 0x80008000, a1, 0xf8; \n" # (restore sign) + "mov.b64 $0, {r0, r1}; \n" + "} \n", + constraints="=l,r", + args = [b], + dtype = tl.float16, + is_pure = True, + pack = 4 + ) * b_scale.to(tl.float16) + # We accumulate along the K dimension. if use_fp8: accumulator = tl.dot(a, b, acc=accumulator) @@ -246,10 +269,10 @@ def invoke_fused_moe_kernel( assert A_scale is None assert B_scale is None else: - #A, A_scale = ops.scaled_fp8_quant(A, A_scale) + # A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None - - #A = triton.reinterpret(A, dtype=tl.uint8) if use_fp8 else A + + # A = triton.reinterpret(A, dtype=tl.uint8) if use_fp8 else A B = triton.reinterpret(B, dtype=tl.uint8) if use_fp8 else B grid = lambda META: ( diff --git a/vllm/model_executor/layers/fused_moe/replace_triton_cuda.py b/vllm/model_executor/layers/fused_moe/replace_triton_cuda.py deleted file mode 100644 index d897b32342db1..0000000000000 --- a/vllm/model_executor/layers/fused_moe/replace_triton_cuda.py +++ /dev/null @@ -1,20 +0,0 @@ -import pkg_resources -import shutil -import os - -def replace_triton_cuda(): - - def get_package_path(package_name): - return pkg_resources.get_distribution(package_name).location - - def get_package_version(package_name): - return pkg_resources.get_distribution(package_name).version - - assert get_package_version('triton') == "2.3.0" or get_package_version('triton') == "2.2.0" - - cur_folder_cuda_py = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'triton_cuda.py') - target_folder_cuda_py = os.path.join(get_package_path('triton'), 'triton', 'language', 'extra', 'cuda.py') - shutil.copyfile(cur_folder_cuda_py, target_folder_cuda_py) - -if __name__ == "__main__": - replace_triton_cuda() \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/triton_cuda.py b/vllm/model_executor/layers/fused_moe/triton_cuda.py deleted file mode 100644 index f66c2618b12f5..0000000000000 --- a/vllm/model_executor/layers/fused_moe/triton_cuda.py +++ /dev/null @@ -1,94 +0,0 @@ -from .. import core - -@core.extern -def globaltimer(_builder=None): - return core.inline_asm_elementwise( - "mov.u64 $0, %globaltimer;", - "=l", - [], - dtype=core.int64, - is_pure=False, - pack=1, - _builder=_builder, - ) - - -@core.extern -def smid(_builder=None): - return core.inline_asm_elementwise( - "mov.u32 $0, %smid;", - "=r", - [], - dtype=core.int32, - is_pure=True, - pack=1, - _builder=_builder, - ) - - -@core.builtin -def num_threads(_builder=None): - return core.constexpr(_builder.target.num_warps * 32) - - -@core.builtin -def num_warps(_builder=None): - return core.constexpr(_builder.options.num_warps) - - -@core.builtin -def convert_uint8_as_fp8e4m3_to_float16(r0, _builder=None): - return core.inline_asm_elementwise( - "{ \n" - ".reg .b32 a<2>, b<2>; \n" # if input = 0xf1f2f3f4 - ".reg .b32 r0, r1; \n" - "prmt.b32 a0, 0, $1, 0x5040; \n" # a0 = 0xf300f400 - "prmt.b32 a1, 0, $1, 0x7060; \n" # a1 = 0xf100f200 - "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" # b0 = a0 & 0x7fff7fff - "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" # (strip sign) - "shr.b32 b0, b0, 1; \n" # b0 >>= 1 - "shr.b32 b1, b1, 1; \n" # shift into fp16 position - "add.u32 b0, b0, 0x20002000; \n" # b0.exp += 2**4-2**3 - # exponent compensate = 8 - "add.u32 b1, b1, 0x20002000; \n" # b1 += 8<<10 | 8<<10<<16 - "lop3.b32 r0, b0, 0x80008000, a0, 0xf8; \n" # out0 = b0|(0x80008000&a0) - "lop3.b32 r1, b1, 0x80008000, a1, 0xf8; \n" # (restore sign) - "mov.b64 $0, {r0, r1}; \n" - "} \n", - "=l,r", - [r0], - dtype=core.float16, - is_pure=True, - pack=4, - _builder=_builder, - ) - - - -# WARN: subnormal (0bs0000xxx) are not handled -@core.builtin -def convert_uint8_as_fp8e4m3_to_bfloat16(r0, _builder=None): - return core.inline_asm_elementwise( - "{ \n" - ".reg .b32 a<2>, b<2>; \n" # if input = 0xf1f2f3f4 - ".reg .b32 r0, r1; \n" - "prmt.b32 a0, 0, $1, 0x5040; \n" # a0 = 0xf300f400 - "prmt.b32 a1, 0, $1, 0x7060; \n" # a1 = 0xf100f200 - "and.b32 b0, a0, 0x7fff7fff; \n" # b0 = a0 & 0x7fff7fff - "and.b32 b1, a1, 0x7fff7fff; \n" # (strip sign) - "shr.b32 b0, b0, 4; \n" # b0 >>= 4 - "shr.b32 b1, b1, 4; \n" # shift into fp16 position - "add.u32 b0, b0, 0x3c003c00; \n" # b0.exp += 2**7-2**3 - # exponent compensate = 120 - "add.u32 b1, b1, 0x3c003c00; \n" # b1 += 120<<7 | 120<<7<<16 - "lop3.b32 r0, b0, 0x80008000, a0, 0xf8; \n" # out0 = b0|(0x80008000&a0) - "lop3.b32 r1, b1, 0x80008000, a1, 0xf8; \n" # (restore sign) - "mov.b64 $0, {r0, r1}; \n" - "} \n", - "=l,r", - [r0], - dtype=core.bfloat16, - is_pure=True, - pack=4, - _builder=_builder, - )