Prefer Extern Operators to nn.SourceModule #1488

Merged
2 changes: 0 additions & 2 deletions python/mlc_chat/compiler/extern/extern.py
@@ -19,7 +19,6 @@
 
 from tvm.target import Target
 
-from ...support.auto_target import detect_cuda_arch_list
 from .flashinfer import FlashInfer
 
 
@@ -69,7 +68,6 @@ def configure(rope_scale: float, rope_theta: float) -> None:
     if store.flashinfer is not None:
         assert store.target.kind.name == "cuda"
         store.flashinfer.configure(
-            arch_list=detect_cuda_arch_list(store.target),
             rope_scale=rope_scale,
             rope_theta=rope_theta,
         )
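
The net effect of this hunk is that configure() no longer detects or forwards a CUDA arch list, since nothing is compiled with nvcc at this stage anymore. A minimal call-site sketch, assuming the module-level configure() shown above is reachable via the import path below (the path is an assumption, not part of this diff):

    # Sketch only: the import path is assumed, not shown in this PR.
    from mlc_chat.compiler.extern import extern

    # Only the RoPE parameters remain; no arch_list is needed.
    extern.configure(rope_scale=1.0, rope_theta=10000.0)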
93 changes: 2 additions & 91 deletions python/mlc_chat/compiler/extern/flashinfer.py
@@ -1,117 +1,28 @@
 """FlashInfer library."""
 import dataclasses
-from typing import List, Optional
-
-from tvm.relax.frontend import nn
 
 
 @dataclasses.dataclass
 class FlashInfer:
-    """A fast kernel library for LLM inference.
-
-    --- Variables ---
-    s: sequence length of the current query
-    t: total sequence length
-    d: head dimension
-    h_q: number of heads in query
-    h_kv: number of heads in key and value
-
-    --- Shapes ---
-    q: [s, h_q, d]
-    k: [t, h_kv, d]
-    v: [t, h_kv, d]
-    o: [1, s, hidden = h_q * d]
-    """
+    """A fast kernel library for LLM inference."""
 
     rope_scale: float = 1.0
     rope_theta: float = 10000.0
-    mod: Optional[nn.SourceModule] = None
 
     def configure(
         self,
-        arch_list: List[int],
         rope_scale: float,
         rope_theta: float,
     ):
-        """Configure FlashInfer as an nn.SourceModule.
+        """Configure FlashInfer as an external operator
 
         Parameters
         ----------
-        arch_list : List[int]
-            List of GPU architectures, e.g. [80, 96, 90]
-
         rope_scale : float
             Scaling factor for the RoPE embedding.
 
         rope_theta : float
             The base period of the RoPE embedding.
         """
-
-        # pylint: disable=no-member,unexpected-keyword-arg,no-value-for-parameter
-        def _infer(q: nn.Tensor, *_args):  # pylint: disable=invalid-name
-            _, s, h_q, d = q.shape  # pylint: disable=invalid-name
-            return nn.Tensor.placeholder((1, s, h_q * d), dtype="float16")
-
-        assert self.mod is None
-
-        compile_options = nn.SourceModule.get_compile_options(
-            source_format="cu",
-            tvm_pkg=["flashinfer/include"],
-        )
-        for arch in arch_list:
-            compile_options += ["-gencode", f"arch=compute_{arch},code=sm_{arch}"]
-
         self.rope_scale = rope_scale
         self.rope_theta = rope_theta
-        self.mod = nn.SourceModule(
-            symbols={
-                "FlashInferSinglePrefillWithKVCache": _infer,
-                "FlashInferSingleDecodeWithKVCache": _infer,
-            },
-            source_code=nn.SourceModule.tvm_home() / "3rdparty/flashinfer/src/tvm_wrapper.cu",
-            source_format="cu",
-            compile_options=compile_options,
-            compiler="nvcc",
-        )
-        nn.add_extern(self.mod)
-        # pylint: enable=no-member,unexpected-keyword-arg,no-value-for-parameter
-
-    def single_batch(  # pylint: disable=invalid-name
-        self,
-        q: nn.Tensor,
-        k: nn.Tensor,
-        v: nn.Tensor,
-    ):
-        """Single batch inference with FlashInfer"""
-        assert self.mod is not None, "FlashInfer module does not exist"
-        assert q.dtype == "float16" and q.ndim == 4
-        assert k.dtype == "float16" and k.ndim == 3
-        assert v.dtype == "float16" and v.ndim == 3
-        _, s, _, _ = q.shape
-        casual = 1  # True
-        qkv_layout = 0  # "NHD", N for seq_len, H for num_heads, D for head_dim
-        rotary_mode = 0  # "kNone"
-        allow_fp16_qk_reduction = 1  # True
-        # Decoding
-        if isinstance(s, int) and s == 1:
-            return self.mod["FlashInferSingleDecodeWithKVCache"](
-                q,
-                k,
-                v,
-                qkv_layout,
-                rotary_mode,
-                self.rope_scale,
-                self.rope_theta,
-            )
-        # Prefilling
-        return self.mod["FlashInferSinglePrefillWithKVCache"](
-            q,
-            k,
-            v,
-            casual,
-            qkv_layout,
-            rotary_mode,
-            allow_fp16_qk_reduction,
-            self.rope_scale,
-            self.rope_theta,
-        )
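
Taken together, the kept and added lines reduce flashinfer.py to a plain holder for the RoPE parameters: the nn.SourceModule compilation path and the single_batch wrapper are gone. The resulting file reads roughly as follows (a reconstruction from the context and added lines above; blank-line placement and the condensed docstring are a best guess, not a verbatim copy):

    """FlashInfer library."""
    import dataclasses


    @dataclasses.dataclass
    class FlashInfer:
        """A fast kernel library for LLM inference."""

        rope_scale: float = 1.0
        rope_theta: float = 10000.0

        def configure(self, rope_scale: float, rope_theta: float):
            """Configure FlashInfer as an external operator."""
            self.rope_scale = rope_scale
            self.rope_theta = rope_theta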
32 changes: 30 additions & 2 deletions python/mlc_chat/compiler/model/extern_op.py
@@ -14,7 +14,7 @@
 ]
 
 
-def attention(  # pylint: disable=invalid-name
+def attention(  # pylint: disable=invalid-name,too-many-locals
     q: nn.Tensor,
     k: nn.Tensor,
     v: nn.Tensor,
@@ -63,7 +63,35 @@ def attention(  # pylint: disable=invalid-name
         and k.dtype == "float16"
         and v.dtype == "float16"
     ):
-        return extern_store.flashinfer.single_batch(q, k, v)
+        rope_scale = extern_store.flashinfer.rope_scale
+        rope_theta = extern_store.flashinfer.rope_theta
+        qkv_layout = 0  # "NHD", N for seq_len, H for num_heads, D for head_dim
+        rotary_mode = 0  # "kNone"
+        casual = 1  # True
+        fp16_qk = 1  # True
+
+        def _decode():
+            return op.extern(  # pylint: disable=no-member
+                name="flashinfer.single_decode",
+                args=[q, k, v, qkv_layout, rotary_mode, rope_scale, rope_theta],
+                out=nn.Tensor.placeholder((b, s, h_q * d), dtype="float16"),
+            )
+
+        def _prefill():
+            return op.extern(  # pylint: disable=no-member
+                name="flashinfer.single_prefill",
+                args=[q, k, v, casual, qkv_layout, rotary_mode, fp16_qk, rope_scale, rope_theta],
+                out=nn.Tensor.placeholder((b, s, h_q * d), dtype="float16"),
+            )
+
+        if isinstance(s, int) and s == 1:
+            func = "decode"
+        else:
+            func = "prefill"
+        return {
+            "decode": _decode,
+            "prefill": _prefill,
+        }[func]()
 
     # Fallback Implementation
     k = op.reshape(k, [b, t, h_kv, d])
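
For reference, the op.extern pattern used above binds a named external symbol and declares only the output placeholder; no CUDA source is compiled while tracing the model. A standalone sketch of the same pattern, mirroring the prefill call from the diff (the module class and the hard-coded constants are illustrative, not additional code from this PR):

    from tvm.relax.frontend import nn
    from tvm.relax.frontend.nn import op


    class ExternAttention(nn.Module):
        """Calls a pre-registered FlashInfer kernel through op.extern."""

        def forward(self, q: nn.Tensor, k: nn.Tensor, v: nn.Tensor) -> nn.Tensor:
            b, s, h_q, d = q.shape
            # casual=1, qkv_layout=0 ("NHD"), rotary_mode=0 ("kNone"), fp16_qk=1,
            # rope_scale=1.0, rope_theta=10000.0: the same constants as the diff above.
            return op.extern(
                name="flashinfer.single_prefill",
                args=[q, k, v, 1, 0, 0, 1, 1.0, 10000.0],
                out=nn.Tensor.placeholder((b, s, h_q * d), dtype="float16"),
            )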