From 41187b412a1474b2c4c7ec15f5f69a67c11c2ede Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Sun, 24 Dec 2023 21:27:36 -0800
Subject: [PATCH] Prefer Extern Operators to `nn.SourceModule`

Depends on https://github.com/apache/tvm/pull/16274.
---
 python/mlc_chat/compiler/extern/extern.py     |  2 -
 python/mlc_chat/compiler/extern/flashinfer.py | 93 +------------------
 python/mlc_chat/compiler/model/extern_op.py   | 32 ++++++-
 3 files changed, 32 insertions(+), 95 deletions(-)

diff --git a/python/mlc_chat/compiler/extern/extern.py b/python/mlc_chat/compiler/extern/extern.py
index bd04d2c6aa..c63ec4a624 100644
--- a/python/mlc_chat/compiler/extern/extern.py
+++ b/python/mlc_chat/compiler/extern/extern.py
@@ -19,7 +19,6 @@
 
 from tvm.target import Target
 
-from ...support.auto_target import detect_cuda_arch_list
 from .flashinfer import FlashInfer
 
 
@@ -69,7 +68,6 @@ def configure(rope_scale: float, rope_theta: float) -> None:
     if store.flashinfer is not None:
         assert store.target.kind.name == "cuda"
         store.flashinfer.configure(
-            arch_list=detect_cuda_arch_list(store.target),
            rope_scale=rope_scale,
            rope_theta=rope_theta,
        )
diff --git a/python/mlc_chat/compiler/extern/flashinfer.py b/python/mlc_chat/compiler/extern/flashinfer.py
index 419ecd5c93..cd4f7805cd 100644
--- a/python/mlc_chat/compiler/extern/flashinfer.py
+++ b/python/mlc_chat/compiler/extern/flashinfer.py
@@ -1,117 +1,28 @@
 """FlashInfer library."""
 import dataclasses
-from typing import List, Optional
-
-from tvm.relax.frontend import nn
 
 
 @dataclasses.dataclass
 class FlashInfer:
-    """A fast kernel library for LLM inference.
-
-    --- Variables ---
-    s: sequence length of the current query
-    t: total sequence length
-    d: head dimension
-    h_q: number of heads in query
-    h_kv: number of heads in key and value
-
-    --- Shapes ---
-    q: [s, h_q, d]
-    k: [t, h_kv, d]
-    v: [t, h_kv, d]
-    o: [1, s, hidden = h_q * d]
-    """
+    """A fast kernel library for LLM inference."""
 
     rope_scale: float = 1.0
     rope_theta: float = 10000.0
-    mod: Optional[nn.SourceModule] = None
 
     def configure(
         self,
-        arch_list: List[int],
         rope_scale: float,
         rope_theta: float,
     ):
-        """Configure FlashInfer as an nn.SourceModule.
+        """Configure FlashInfer as an external operator
 
         Parameters
         ----------
-        arch_list : List[int]
-            List of GPU architectures, e.g. [80, 96, 90]
-
         rope_scale : float
             Scaling factor for the RoPE embedding.
 
         rope_theta : float
             The base period of the RoPE embedding.
         """
-
-        # pylint: disable=no-member,unexpected-keyword-arg,no-value-for-parameter
-        def _infer(q: nn.Tensor, *_args):  # pylint: disable=invalid-name
-            _, s, h_q, d = q.shape  # pylint: disable=invalid-name
-            return nn.Tensor.placeholder((1, s, h_q * d), dtype="float16")
-
-        assert self.mod is None
-
-        compile_options = nn.SourceModule.get_compile_options(
-            source_format="cu",
-            tvm_pkg=["flashinfer/include"],
-        )
-        for arch in arch_list:
-            compile_options += ["-gencode", f"arch=compute_{arch},code=sm_{arch}"]
-
         self.rope_scale = rope_scale
         self.rope_theta = rope_theta
-        self.mod = nn.SourceModule(
-            symbols={
-                "FlashInferSinglePrefillWithKVCache": _infer,
-                "FlashInferSingleDecodeWithKVCache": _infer,
-            },
-            source_code=nn.SourceModule.tvm_home() / "3rdparty/flashinfer/src/tvm_wrapper.cu",
-            source_format="cu",
-            compile_options=compile_options,
-            compiler="nvcc",
-        )
-        nn.add_extern(self.mod)
-        # pylint: enable=no-member,unexpected-keyword-arg,no-value-for-parameter
-
-    def single_batch(  # pylint: disable=invalid-name
-        self,
-        q: nn.Tensor,
-        k: nn.Tensor,
-        v: nn.Tensor,
-    ):
-        """Single batch inference with FlashInfer"""
-        assert self.mod is not None, "FlashInfer module does not exist"
-        assert q.dtype == "float16" and q.ndim == 4
-        assert k.dtype == "float16" and k.ndim == 3
-        assert v.dtype == "float16" and v.ndim == 3
-        _, s, _, _ = q.shape
-        casual = 1  # True
-        qkv_layout = 0  # "NHD", N for seq_len, H for num_heads, D for head_dim
-        rotary_mode = 0  # "kNone"
-        allow_fp16_qk_reduction = 1  # True
-        # Decoding
-        if isinstance(s, int) and s == 1:
-            return self.mod["FlashInferSingleDecodeWithKVCache"](
-                q,
-                k,
-                v,
-                qkv_layout,
-                rotary_mode,
-                self.rope_scale,
-                self.rope_theta,
-            )
-        # Prefilling
-        return self.mod["FlashInferSinglePrefillWithKVCache"](
-            q,
-            k,
-            v,
-            casual,
-            qkv_layout,
-            rotary_mode,
-            allow_fp16_qk_reduction,
-            self.rope_scale,
-            self.rope_theta,
-        )
diff --git a/python/mlc_chat/compiler/model/extern_op.py b/python/mlc_chat/compiler/model/extern_op.py
index 88bdee8d60..75dd48ef2c 100644
--- a/python/mlc_chat/compiler/model/extern_op.py
+++ b/python/mlc_chat/compiler/model/extern_op.py
@@ -14,7 +14,7 @@
 ]
 
 
-def attention(  # pylint: disable=invalid-name
+def attention(  # pylint: disable=invalid-name,too-many-locals
     q: nn.Tensor,
     k: nn.Tensor,
     v: nn.Tensor,
@@ -63,7 +63,35 @@ def attention(  # pylint: disable=invalid-name
         and k.dtype == "float16"
         and v.dtype == "float16"
     ):
-        return extern_store.flashinfer.single_batch(q, k, v)
+        rope_scale = extern_store.flashinfer.rope_scale
+        rope_theta = extern_store.flashinfer.rope_theta
+        qkv_layout = 0  # "NHD", N for seq_len, H for num_heads, D for head_dim
+        rotary_mode = 0  # "kNone"
+        casual = 1  # True
+        fp16_qk = 1  # True
+
+        def _decode():
+            return op.extern(  # pylint: disable=no-member
+                name="flashinfer.single_decode",
+                args=[q, k, v, qkv_layout, rotary_mode, rope_scale, rope_theta],
+                out=nn.Tensor.placeholder((b, s, h_q * d), dtype="float16"),
+            )
+
+        def _prefill():
+            return op.extern(  # pylint: disable=no-member
+                name="flashinfer.single_prefill",
+                args=[q, k, v, casual, qkv_layout, rotary_mode, fp16_qk, rope_scale, rope_theta],
+                out=nn.Tensor.placeholder((b, s, h_q * d), dtype="float16"),
+            )
+
+        if isinstance(s, int) and s == 1:
+            func = "decode"
+        else:
+            func = "prefill"
+        return {
+            "decode": _decode,
+            "prefill": _prefill,
+        }[func]()
 
     # Fallback Implementation
     k = op.reshape(k, [b, t, h_kv, d])