-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
wip wip done remove flashinfer add doc upd fix lint
- Loading branch information
Showing
4 changed files
with
255 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
""" | ||
Extern module for compiler. | ||
""" | ||
from .flashinfer import FlashInfer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
""" | ||
FlashInfer source module. | ||
""" | ||
import logging | ||
import os | ||
from enum import IntEnum | ||
|
||
import tvm | ||
from tvm._ffi.libinfo import find_include_path | ||
from tvm.relax.frontend import nn | ||
from tvm.relax.frontend.nn import spec | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class _FlashInfer(nn.SourceModule): | ||
class QKVLayout(IntEnum): | ||
"""The layout of Q, K and V tensor. N for seq_len, H for num_heads and D for dim_head.""" | ||
|
||
NHD = 0 | ||
HND = 1 | ||
|
||
class RotaryMode(IntEnum): | ||
"""The rotary embedding mode. kNone is no rotary embedding and | ||
kLlama is Llama style rotary embedding.""" | ||
|
||
kNone = 0 # pylint: disable=invalid-name | ||
kLlama = 1 # pylint: disable=invalid-name | ||
|
||
def __init__(self) -> None: | ||
tvm_path = os.path.join(os.path.dirname(tvm.__file__), "..", "..") | ||
tvm_3rdparty_path = os.path.join(tvm_path, "3rdparty") | ||
flashinfer_path = os.path.join(tvm_3rdparty_path, "flashinfer") | ||
with open( | ||
os.path.join(flashinfer_path, "src", "tvm_wrapper.cu"), "r", encoding="utf-8" | ||
) as f: | ||
source_code = f.read() | ||
tvm_include_path = find_include_path(optional=True) | ||
if tvm_include_path is None: | ||
tvm_include_path = [ | ||
os.path.join(tvm_path, "include"), | ||
os.path.join(tvm_3rdparty_path, "dlpack", "include"), | ||
os.path.join(tvm_3rdparty_path, "dmlc-core", "include"), | ||
] | ||
compile_options = [ | ||
"-c", | ||
"-arch=native", | ||
"-O3", | ||
"-x", | ||
"cu", | ||
"-I", | ||
f"{flashinfer_path}/include", | ||
"-DDMLC_USE_FOPEN64=0", | ||
"-DDMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>", | ||
] | ||
for p in tvm_include_path: | ||
compile_options.extend(["-I", p]) | ||
compile_options.append("-Xcompiler=-fPIC") | ||
super().__init__( | ||
source_code=source_code, | ||
source_format="cu", | ||
functions={ | ||
"FlashInferSinglePrefillWithKVCache": spec.ExternFunctionSpec( | ||
args=[ | ||
spec.Tensor(("qo_len", "num_qo_heads", "head_dim"), "float16"), # q | ||
spec.Tensor(("kv_dim_0", "kv_dim_1", "head_dim"), "float16"), # k | ||
spec.Tensor(("kv_dim_0", "kv_dim_1", "head_dim"), "float16"), # v | ||
spec.ConstInt(), # causal | ||
spec.ConstInt(), # qkv_layout | ||
spec.ConstInt(), # rotary_mode | ||
spec.ConstInt(), # allow_fp16_qk_reduction | ||
spec.ConstFloat(), # rope_scale | ||
spec.ConstFloat(), # rope_theta | ||
spec.Tensor(("qo_len", "num_qo_heads", "head_dim"), "float16"), # o | ||
], | ||
ret=spec.Tensor(("qo_len", "num_qo_heads", "head_dim"), "float16"), | ||
), | ||
"FlashInferSingleDecodeWithKVCache": spec.ExternFunctionSpec( | ||
args=[ | ||
spec.Tensor(("num_qo_heads", "head_dim"), "float16"), # q | ||
spec.Tensor(("kv_dim_0", "kv_dim_1", "head_dim"), "float16"), # k | ||
spec.Tensor(("kv_dim_0", "kv_dim_1", "head_dim"), "float16"), # v | ||
spec.ConstInt(), # qkv_layout | ||
spec.ConstInt(), # rotary_mode | ||
spec.ConstFloat(), # rope_scale | ||
spec.ConstFloat(), # rope_theta | ||
spec.Tensor(("num_qo_heads", "head_dim"), "float16"), # o | ||
], | ||
ret=spec.Tensor(("num_qo_heads", "head_dim"), "float16"), | ||
), | ||
}, | ||
compile_options=compile_options, | ||
compiler="nvcc", | ||
output_format="obj", | ||
) | ||
|
||
def single_decode_with_kvcache( # pylint: disable=too-many-arguments | ||
self, | ||
q: nn.Tensor, | ||
k: nn.Tensor, | ||
v: nn.Tensor, | ||
qkv_layout: QKVLayout = QKVLayout.NHD, | ||
rotary_mode: RotaryMode = RotaryMode.kNone, | ||
rope_scale: float = 1.0, | ||
rope_theta: float = 1e4, | ||
): | ||
""" | ||
FlashInfer single-batched decode kernel with kvcache. | ||
Parameters | ||
---------- | ||
q : nn.Tensor | ||
The input Q tensor. | ||
k : nn.Tensor | ||
The input K tensor. | ||
v : nn.Tensor | ||
The input V tensor. | ||
qkv_layout : QKVLayout | ||
The layout type of Q, K and V. | ||
rotary_mode: RotaryMode | ||
The rotary mode of rotary embedding. | ||
rope_scale: float | ||
The rotary embedding scale. | ||
rope_theta: float | ||
The rotary embedding theta. | ||
Returns | ||
------- | ||
ret : nn.Tensor | ||
The output tensor of FlashInfer single-batched decode with kvcache. | ||
""" | ||
return self.get_extern_func("FlashInferSingleDecodeWithKVCache")( | ||
q, k, v, int(qkv_layout), int(rotary_mode), rope_scale, rope_theta | ||
) | ||
|
||
def single_prefill_with_kvcache( # pylint: disable=too-many-arguments | ||
self, | ||
q: nn.Tensor, | ||
k: nn.Tensor, | ||
v: nn.Tensor, | ||
causal: bool = True, | ||
qkv_layout: QKVLayout = QKVLayout.NHD, | ||
rotary_mode: RotaryMode = RotaryMode.kNone, | ||
allow_fp16_qk_reduction: bool = False, | ||
rope_scale: float = 1.0, | ||
rope_theta: float = 1e4, | ||
): | ||
""" | ||
FlashInfer single-batched prefill kernel with kvcache. | ||
Parameters | ||
---------- | ||
q : nn.Tensor | ||
The input Q tensor. | ||
k : nn.Tensor | ||
The input K tensor. | ||
v : nn.Tensor | ||
The input V tensor. | ||
causal : bool | ||
If the attention is causal. | ||
qkv_layout : QKVLayout | ||
The layout type of Q, K and V. | ||
rotary_mode: RotaryMode | ||
The rotary mode of rotary embedding. | ||
allow_fp16_qk_reduction : bool | ||
If the Q and K matmul reduction is computed under float16. | ||
rope_scale: float | ||
The rotary embedding scale. | ||
rope_theta: float | ||
The rotary embedding theta. | ||
Returns | ||
------- | ||
ret : nn.Tensor | ||
The output tensor of FlashInfer single-batched prefill with kvcache. | ||
""" | ||
return self.get_extern_func("FlashInferSinglePrefillWithKVCache")( | ||
q, | ||
k, | ||
v, | ||
causal, | ||
int(qkv_layout), | ||
int(rotary_mode), | ||
allow_fp16_qk_reduction, | ||
rope_scale, | ||
rope_theta, | ||
) | ||
|
||
|
||
FlashInfer = _FlashInfer() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters