Merge pull request vllm-project#11 from wenxcs/wenxh/fp8-on-a100-v1-pr
Wenxh/fp8 on a100 v1 pr
xiaoxiawu-microsoft authored May 31, 2024
2 parents 94d4614 + 22a9f82 commit 03e3bda
Showing 3 changed files with 30 additions and 121 deletions.
37 changes: 30 additions & 7 deletions vllm/model_executor/layers/fused_moe/ampere_fp8_fused_moe.py
@@ -1,5 +1,4 @@
"""Fused MoE kernel."""

import functools
import json
import os
@@ -120,7 +119,6 @@ def fused_moe_kernel(
)

if use_fp8:
#a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)

# -----------------------------------------------------------
@@ -138,9 +136,34 @@ def fused_moe_kernel(
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
).to(tl.float16)
#a = tl.extra.cuda.convert_uint8_as_fp8e4m3_to_float16(a)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.extra.cuda.convert_uint8_as_fp8e4m3_to_float16(b) * b_scale.to(tl.float16)

# TODO(wenxh): Triton 2.2/2.3 has a bug where only the "=l" output
# constraint works; "=r" trips an LLVM check (low-level bug).
b = tl.inline_asm_elementwise(
    asm="{ \n"
        ".reg .b32 a<2>, b<2>; \n"  # if input = 0xf1f2f3f4
        ".reg .b32 r0, r1; \n"
        "prmt.b32 a0, 0, $1, 0x5040; \n"  # a0 = 0xf300f400
        "prmt.b32 a1, 0, $1, 0x7060; \n"  # a1 = 0xf100f200
        "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"  # b0 = a0 & 0x7fff7fff
        "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"  # (strip sign)
        "shr.b32 b0, b0, 1; \n"  # b0 >>= 1
        "shr.b32 b1, b1, 1; \n"  # shift into fp16 position
        "add.u32 b0, b0, 0x20002000; \n"  # add 8 to each fp16 exponent:
        "add.u32 b1, b1, 0x20002000; \n"  # bias 15 (fp16) - 7 (fp8e4m3) = 8
        "lop3.b32 r0, b0, 0x80008000, a0, 0xf8; \n"  # out0 = b0|(0x80008000&a0)
        "lop3.b32 r1, b1, 0x80008000, a1, 0xf8; \n"  # (restore sign)
        "mov.b64 $0, {r0, r1}; \n"
        "} \n",
    constraints="=l,r",
    args=[b],
    dtype=tl.float16,
    is_pure=True,
    pack=4,
) * b_scale.to(tl.float16)

# We accumulate along the K dimension.
if use_fp8:
accumulator = tl.dot(a, b, acc=accumulator)
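
For reference (not part of this commit), the PTX conversion above can be checked on the CPU. A minimal NumPy sketch of the same bit trick, valid for normal fp8e4m3 values only — subnormals and NaN fall outside this fast path, and the helper name is ours:

import numpy as np

def fp8e4m3_to_fp16_bits(x: np.ndarray) -> np.ndarray:
    # Mirrors the PTX: byte into the high half (prmt), strip sign and
    # shift (lop3 + shr), rebias the exponent (add.u32), restore sign (lop3).
    h = x.astype(np.uint16) << 8
    mag = ((h & 0x7FFF) >> 1) + (8 << 10)  # bias: 15 (fp16) - 7 (fp8) = 8
    return (mag | (h & 0x8000)).view(np.float16)

# 0x40 encodes sign=0, exp=8, mant=0 -> 1.0 * 2^(8-7) = 2.0
assert fp8e4m3_to_fp16_bits(np.array([0x40], dtype=np.uint8))[0] == np.float16(2.0)
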
@@ -246,10 +269,10 @@ def invoke_fused_moe_kernel(
assert A_scale is None
assert B_scale is None
else:
#A, A_scale = ops.scaled_fp8_quant(A, A_scale)
# A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
#A = triton.reinterpret(A, dtype=tl.uint8) if use_fp8 else A

# A = triton.reinterpret(A, dtype=tl.uint8) if use_fp8 else A
B = triton.reinterpret(B, dtype=tl.uint8) if use_fp8 else B

grid = lambda META: (
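Aside (not from this commit): triton.reinterpret only retypes the pointer the kernel sees; the stored bytes are untouched. The same idea in plain PyTorch, assuming torch >= 2.1 for the fp8 dtype:

import torch

w = torch.tensor([2.0, -0.5]).to(torch.float8_e4m3fn)
raw = w.view(torch.uint8)              # same storage, byte-typed view
print([hex(v) for v in raw.tolist()])  # ['0x40', '0xb0']
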
20 changes: 0 additions & 20 deletions vllm/model_executor/layers/fused_moe/replace_triton_cuda.py

This file was deleted.

94 changes: 0 additions & 94 deletions vllm/model_executor/layers/fused_moe/triton_cuda.py

This file was deleted.
