Commit

init

wip

wip

done

remove flashinfer

add doc

upd

fix lint
cyx-6 authored and junrushao committed Dec 15, 2023
1 parent a7a1681 commit a5c009c
Showing 4 changed files with 255 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/mlc_chat/compiler/__init__.py
@@ -5,6 +5,7 @@
from . import compiler_pass
from .compile import CompileArgs, compile # pylint: disable=redefined-builtin
from .convert_weight import ConversionArgs, convert_weight
from .extern import FlashInfer
from .flags_model_config_override import ModelConfigOverride
from .flags_optimization import OptimizationFlags
from .gen_config import CONV_TEMPLATES, gen_config
4 changes: 4 additions & 0 deletions python/mlc_chat/compiler/extern/__init__.py
@@ -0,0 +1,4 @@
"""
Extern modules for the compiler.
"""
from .flashinfer import FlashInfer
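
Taken together with the re-export added to python/mlc_chat/compiler/__init__.py above, the singleton becomes importable from either package level. A small sanity-check sketch, illustrative and not part of the commit; it assumes a TVM build that ships 3rdparty/flashinfer, since the singleton reads its CUDA source at import time:

# Illustrative check: both import paths resolve to the same singleton object.
from mlc_chat.compiler import FlashInfer as flashinfer_a
from mlc_chat.compiler.extern import FlashInfer as flashinfer_b

assert flashinfer_a is flashinfer_b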
204 changes: 204 additions & 0 deletions python/mlc_chat/compiler/extern/flashinfer.py
@@ -0,0 +1,204 @@
"""
FlashInfer source module.
"""
import logging
import os
from enum import IntEnum

import tvm
from tvm._ffi.libinfo import find_include_path
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import spec

logger = logging.getLogger(__name__)


class _FlashInfer(nn.SourceModule):
class QKVLayout(IntEnum):
"""The layout of Q, K and V tensor. N for seq_len, H for num_heads and D for dim_head."""

NHD = 0
HND = 1

class RotaryMode(IntEnum):
"""The rotary embedding mode. kNone is no rotary embedding and
kLlama is Llama style rotary embedding."""

kNone = 0 # pylint: disable=invalid-name
kLlama = 1 # pylint: disable=invalid-name

def __init__(self) -> None:
tvm_path = os.path.join(os.path.dirname(tvm.__file__), "..", "..")
tvm_3rdparty_path = os.path.join(tvm_path, "3rdparty")
flashinfer_path = os.path.join(tvm_3rdparty_path, "flashinfer")
with open(
os.path.join(flashinfer_path, "src", "tvm_wrapper.cu"), "r", encoding="utf-8"
) as f:
source_code = f.read()
tvm_include_path = find_include_path(optional=True)
if tvm_include_path is None:
tvm_include_path = [
os.path.join(tvm_path, "include"),
os.path.join(tvm_3rdparty_path, "dlpack", "include"),
os.path.join(tvm_3rdparty_path, "dmlc-core", "include"),
]
compile_options = [
"-c",
"-arch=native",
"-O3",
"-x",
"cu",
"-I",
f"{flashinfer_path}/include",
"-DDMLC_USE_FOPEN64=0",
"-DDMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>",
]
for p in tvm_include_path:
compile_options.extend(["-I", p])
compile_options.append("-Xcompiler=-fPIC")
super().__init__(
source_code=source_code,
source_format="cu",
functions={
"FlashInferSinglePrefillWithKVCache": spec.ExternFunctionSpec(
args=[
spec.Tensor(("qo_len", "num_qo_heads", "head_dim"), "float16"), # q
spec.Tensor(("kv_dim_0", "kv_dim_1", "head_dim"), "float16"), # k
spec.Tensor(("kv_dim_0", "kv_dim_1", "head_dim"), "float16"), # v
spec.ConstInt(), # causal
spec.ConstInt(), # qkv_layout
spec.ConstInt(), # rotary_mode
spec.ConstInt(), # allow_fp16_qk_reduction
spec.ConstFloat(), # rope_scale
spec.ConstFloat(), # rope_theta
spec.Tensor(("qo_len", "num_qo_heads", "head_dim"), "float16"), # o
],
ret=spec.Tensor(("qo_len", "num_qo_heads", "head_dim"), "float16"),
),
"FlashInferSingleDecodeWithKVCache": spec.ExternFunctionSpec(
args=[
spec.Tensor(("num_qo_heads", "head_dim"), "float16"), # q
spec.Tensor(("kv_dim_0", "kv_dim_1", "head_dim"), "float16"), # k
spec.Tensor(("kv_dim_0", "kv_dim_1", "head_dim"), "float16"), # v
spec.ConstInt(), # qkv_layout
spec.ConstInt(), # rotary_mode
spec.ConstFloat(), # rope_scale
spec.ConstFloat(), # rope_theta
spec.Tensor(("num_qo_heads", "head_dim"), "float16"), # o
],
ret=spec.Tensor(("num_qo_heads", "head_dim"), "float16"),
),
},
compile_options=compile_options,
compiler="nvcc",
output_format="obj",
)

def single_decode_with_kvcache( # pylint: disable=too-many-arguments
self,
q: nn.Tensor,
k: nn.Tensor,
v: nn.Tensor,
qkv_layout: QKVLayout = QKVLayout.NHD,
rotary_mode: RotaryMode = RotaryMode.kNone,
rope_scale: float = 1.0,
rope_theta: float = 1e4,
):
"""
FlashInfer single-batched decode kernel with kvcache.
Parameters
----------
q : nn.Tensor
The input Q tensor.
k : nn.Tensor
The input K tensor.
v : nn.Tensor
The input V tensor.
qkv_layout : QKVLayout
The layout type of Q, K and V.
rotary_mode: RotaryMode
The rotary mode of rotary embedding.
rope_scale: float
The rotary embedding scale.
rope_theta: float
The rotary embedding theta.
Returns
-------
ret : nn.Tensor
The output tensor of FlashInfer single-batched decode with kvcache.
"""
return self.get_extern_func("FlashInferSingleDecodeWithKVCache")(
q, k, v, int(qkv_layout), int(rotary_mode), rope_scale, rope_theta
)

def single_prefill_with_kvcache( # pylint: disable=too-many-arguments
self,
q: nn.Tensor,
k: nn.Tensor,
v: nn.Tensor,
causal: bool = True,
qkv_layout: QKVLayout = QKVLayout.NHD,
rotary_mode: RotaryMode = RotaryMode.kNone,
allow_fp16_qk_reduction: bool = False,
rope_scale: float = 1.0,
rope_theta: float = 1e4,
):
"""
FlashInfer single-batched prefill kernel with kvcache.
Parameters
----------
q : nn.Tensor
The input Q tensor.
k : nn.Tensor
The input K tensor.
v : nn.Tensor
The input V tensor.
causal : bool
If the attention is causal.
qkv_layout : QKVLayout
The layout type of Q, K and V.
rotary_mode: RotaryMode
The rotary mode of rotary embedding.
allow_fp16_qk_reduction : bool
If the Q and K matmul reduction is computed under float16.
rope_scale: float
The rotary embedding scale.
rope_theta: float
The rotary embedding theta.
Returns
-------
ret : nn.Tensor
The output tensor of FlashInfer single-batched prefill with kvcache.
"""
return self.get_extern_func("FlashInferSinglePrefillWithKVCache")(
q,
k,
v,
causal,
int(qkv_layout),
int(rotary_mode),
allow_fp16_qk_reduction,
rope_scale,
rope_theta,
)


FlashInfer = _FlashInfer()
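
For orientation, here is a minimal usage sketch of the extern module defined above. The TinyAttention module and its inputs are illustrative assumptions, not part of this commit; the shapes follow the ExternFunctionSpec for the decode kernel (q of shape (num_qo_heads, head_dim), k and v of shape (kv_len, num_kv_heads, head_dim) in NHD layout, all float16):

# Illustrative sketch, not part of this commit: calling the decode kernel
# from a tvm.relax.frontend.nn module. Shapes follow the spec above.
from tvm.relax.frontend import nn

from mlc_chat.compiler.extern.flashinfer import FlashInfer


class TinyAttention(nn.Module):
    """Hypothetical module that forwards q/k/v to the FlashInfer decode kernel."""

    def forward(self, q: nn.Tensor, k: nn.Tensor, v: nn.Tensor) -> nn.Tensor:
        # q: (num_qo_heads, head_dim); k, v: (kv_len, num_kv_heads, head_dim), float16.
        return FlashInfer.single_decode_with_kvcache(
            q,
            k,
            v,
            qkv_layout=FlashInfer.QKVLayout.NHD,
            rotary_mode=FlashInfer.RotaryMode.kNone,
        )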
47 changes: 46 additions & 1 deletion python/mlc_chat/compiler/model/llama/llama_model.py
@@ -9,11 +9,13 @@
from tvm import te, tir
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op
from tvm.target import Target

from ....support import logging
from ....support.config import ConfigBase
from ....support.style import bold
from ... import tensor_parallel as tp
from ...extern.flashinfer import FlashInfer

logger = logging.getLogger(__name__)

@@ -138,6 +140,8 @@ def __init__(self, config: LlamaConfig, rotary_embedding: RotaryEmbedding):
        self.k_cache = nn.KVCache(config.context_window_size, [self.num_kv_heads, self.head_dim])
        self.v_cache = nn.KVCache(config.context_window_size, [self.num_kv_heads, self.head_dim])

        self.flashinfer = FlashInfer

    def forward(  # pylint: disable=too-many-locals
        self,
        hidden_states: Tensor,
@@ -151,15 +155,56 @@ def forward(  # pylint: disable=too-many-locals
        qkv = self.qkv_proj(hidden_states)
        qkv = op.reshape(qkv, (b, s, h_q + 2 * h_kv, d))
        q, k, v = op.split(qkv, indices_or_sections=[h_q, h_q + h_kv], axis=2)
        q, k = self.rotary_embedding(q, k, t - s)
        # q, k = self.rotary_embedding(q, k, t - s)

        self.k_cache.append(op.squeeze(k, axis=0))
        self.v_cache.append(op.squeeze(v, axis=0))
        k = op.reshape(self.k_cache.view(t), (b, t, h_kv, d))
        v = op.reshape(self.v_cache.view(t), (b, t, h_kv, d))

        current_target = Target.current(allow_none=True)
        if current_target and current_target.kind.name == "cuda":
            # enable FlashInfer if target is cuda
            if s == 1:
                # FlashInfer decode
                q = op.reshape(q, (h_q, d))
                k = op.reshape(k, (t, h_kv, d))
                v = op.reshape(v, (t, h_kv, d))
                attn = self.flashinfer.single_decode_with_kvcache(
                    q.astype("float16"),
                    k.astype("float16"),
                    v.astype("float16"),
                    qkv_layout=FlashInfer.QKVLayout.NHD,
                    rotary_mode=FlashInfer.RotaryMode.kLlama,
                    rope_scale=1,
                    rope_theta=1e4,
                )
                attn = op.reshape(attn.astype(q.dtype), (b, s, self.hidden_size))
                return self.o_proj(attn)
            # FlashInfer prefill
            q = op.reshape(q, (s, h_q, d))
            k = op.reshape(k, (t, h_kv, d))
            v = op.reshape(v, (t, h_kv, d))
            attn = self.flashinfer.single_prefill_with_kvcache(
                q.astype("float16"),
                k.astype("float16"),
                v.astype("float16"),
                causal=True,
                qkv_layout=FlashInfer.QKVLayout.NHD,
                rotary_mode=FlashInfer.RotaryMode.kLlama,
                allow_fp16_qk_reduction=False,
                rope_scale=1,
                rope_theta=1e4,
            )
            attn = op.reshape(attn.astype(q.dtype), (b, s, self.hidden_size))
            return self.o_proj(attn)

        if h_kv != h_q:
            k = k.repeat(h_q // h_kv, axis=2)
            v = v.repeat(h_q // h_kv, axis=2)

        q, k = self.rotary_embedding(q, k, t - s)

        q = q.permute_dims([0, 2, 1, 3])  # [b, h, s, d]
        k = k.permute_dims([0, 2, 1, 3])  # [b, h, t, d]
        v = v.permute_dims([0, 2, 1, 3])  # [b, h, t, d]
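Note that the FlashInfer branch above is gated on Target.current(), so it is only taken when the model is traced or compiled inside a CUDA target scope; on any other target the existing repeat/rotary/permute path still runs. A minimal sketch of entering such a scope, with the surrounding build code assumed rather than shown in this diff:

# Illustrative only: inside this scope, Target.current(allow_none=True).kind.name
# evaluates to "cuda", so LlamaAttention.forward dispatches to the FlashInfer kernels.
from tvm.target import Target

with Target("cuda"):
    ...  # trace or compile the Llama model here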
