Skip to content

Commit

Permalink
Prefer Extern Operators to nn.SourceModule (#1488)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored Dec 26, 2023
1 parent 88572b9 commit 3e3ccf9
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 95 deletions.
2 changes: 0 additions & 2 deletions python/mlc_chat/compiler/extern/extern.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from tvm.target import Target

from ...support.auto_target import detect_cuda_arch_list
from .flashinfer import FlashInfer


Expand Down Expand Up @@ -69,7 +68,6 @@ def configure(rope_scale: float, rope_theta: float) -> None:
if store.flashinfer is not None:
assert store.target.kind.name == "cuda"
store.flashinfer.configure(
arch_list=detect_cuda_arch_list(store.target),
rope_scale=rope_scale,
rope_theta=rope_theta,
)
93 changes: 2 additions & 91 deletions python/mlc_chat/compiler/extern/flashinfer.py
Original file line number Diff line number Diff line change
@@ -1,117 +1,28 @@
"""FlashInfer library."""
import dataclasses
from typing import List, Optional

from tvm.relax.frontend import nn


@dataclasses.dataclass
class FlashInfer:
    """A fast kernel library for LLM inference.

    After the move away from ``nn.SourceModule``, this class no longer compiles
    or registers any CUDA source itself; it only stores the RoPE (rotary
    position embedding) parameters that call sites pass to FlashInfer's
    external attention operators (e.g. ``flashinfer.single_decode`` /
    ``flashinfer.single_prefill``).
    """

    # Scaling factor for the RoPE embedding.
    rope_scale: float = 1.0
    # The base period of the RoPE embedding.
    rope_theta: float = 10000.0

    def configure(
        self,
        rope_scale: float,
        rope_theta: float,
    ) -> None:
        """Configure FlashInfer as an external operator.

        Parameters
        ----------
        rope_scale : float
            Scaling factor for the RoPE embedding.

        rope_theta : float
            The base period of the RoPE embedding.
        """
        self.rope_scale = rope_scale
        self.rope_theta = rope_theta
32 changes: 30 additions & 2 deletions python/mlc_chat/compiler/model/extern_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
]


def attention( # pylint: disable=invalid-name
def attention( # pylint: disable=invalid-name,too-many-locals
q: nn.Tensor,
k: nn.Tensor,
v: nn.Tensor,
Expand Down Expand Up @@ -63,7 +63,35 @@ def attention( # pylint: disable=invalid-name
and k.dtype == "float16"
and v.dtype == "float16"
):
return extern_store.flashinfer.single_batch(q, k, v)
rope_scale = extern_store.flashinfer.rope_scale
rope_theta = extern_store.flashinfer.rope_theta
qkv_layout = 0 # "NHD", N for seq_len, H for num_heads, D for head_dim
rotary_mode = 0 # "kNone"
casual = 1 # True
fp16_qk = 1 # True

def _decode():
return op.extern( # pylint: disable=no-member
name="flashinfer.single_decode",
args=[q, k, v, qkv_layout, rotary_mode, rope_scale, rope_theta],
out=nn.Tensor.placeholder((b, s, h_q * d), dtype="float16"),
)

def _prefill():
return op.extern( # pylint: disable=no-member
name="flashinfer.single_prefill",
args=[q, k, v, casual, qkv_layout, rotary_mode, fp16_qk, rope_scale, rope_theta],
out=nn.Tensor.placeholder((b, s, h_q * d), dtype="float16"),
)

if isinstance(s, int) and s == 1:
func = "decode"
else:
func = "prefill"
return {
"decode": _decode,
"prefill": _prefill,
}[func]()

# Fallback Implementation
k = op.reshape(k, [b, t, h_kv, d])
Expand Down

0 comments on commit 3e3ccf9

Please sign in to comment.