[Pass] PruneRelaxFunc to remove Relax function based on target
We recently noticed that when FlashInfer is not built due to an
unsupported CUDA architecture or platform, running the single-sequence
ChatModule hits a VM function initialization error: the missing
function is referenced by `create_flashinfer_paged_kv_cache`, which
is not actually invoked in the single-sequence flow.

This happens because the Relax VM eagerly initializes every referenced
PackedFunc at VM initialization time instead of loading them lazily.
So even though `create_flashinfer_paged_kv_cache` is never invoked,
its PackedFuncs are still looked up, and the error occurs whenever
FlashInfer is unavailable.

This PR adds a compiler pass that removes
`create_flashinfer_paged_kv_cache` (and similar functions that may be
introduced in the future) based on the target. This pass effectively
addresses the issue.
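
For illustration, a minimal sketch (not part of this change, and the FlashInfer
function name below is hypothetical) of the eager-lookup behavior described above:

import tvm

def eager_lookup(extern_func_names):
    # Mirror the Relax VM behavior: resolve every referenced PackedFunc up front,
    # so a single missing registration fails initialization even if the function
    # is never called.
    return {name: tvm.get_global_func(name) for name in extern_func_names}

# With FlashInfer unavailable, this raises during "initialization", not at call time:
# eager_lookup(["flashinfer.some_paged_kv_cache_kernel"])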
MasterJH5574 committed Jan 7, 2024
1 parent 5e23900 commit dbb7672
Showing 3 changed files with 38 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/mlc_chat/compiler_pass/pipeline.py
@@ -19,6 +19,7 @@
from .fuse_dequantize_transpose import FuseDequantizeTranspose
from .fuse_transpose_matmul import FuseTransposeMatmul
from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc
from .prune_relax_func import PruneRelaxFunc

logger = logging.getLogger(__name__)

@@ -58,6 +59,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:

@register_pipeline("mlc_llm")
def _mlc_llm_pipeline( # pylint: disable=too-many-arguments
flashinfer: bool,
cublas_gemm: bool,
variable_bounds: Dict[str, int] = None,
additional_tirs: Dict[str, tvm.tir.PrimFunc] = None,
@@ -75,7 +77,8 @@ def _mlc_llm_pipeline(  # pylint: disable=too-many-arguments
def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
seq = tvm.transform.Sequential(
[
# Phase 0. Add additional information for compilation
# Phase 0. Add additional information for compilation and remove unused Relax func
PruneRelaxFunc(flashinfer=flashinfer),
AttachVariableBounds(variable_bounds),
AttachAdditionalPrimFuncs(additional_tirs),
_DebugDump("debug-phase0.py", debug_dump, show_meta=False),
33 changes: 33 additions & 0 deletions python/mlc_chat/compiler_pass/prune_relax_func.py
@@ -0,0 +1,33 @@
"""A pass that removes undesired Relax function from IRModule based on target."""
import tvm
from tvm import IRModule

from mlc_chat.support.auto_target import detect_cuda_arch_list


@tvm.transform.module_pass(opt_level=0, name="PruneRelaxFunc")
class PruneRelaxFunc: # pylint: disable=too-few-public-methods
"""Removes undesired Relax function from IRModule based on target."""

def __init__(self, flashinfer: bool) -> None:
"""Initializer.
Parameters
----------
flashinfer : bool
A boolean indicating if flashinfer is enabled.
"""
self.flashinfer = flashinfer

def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
"""Entrypoint"""
func_dict = {}
for g_var, func in mod.functions_items():
# Remove "create_flashinfer_paged_kv_cache" for unsupported target
if g_var.name_hint == "create_flashinfer_paged_kv_cache" and not self.flashinfer:
continue
func_dict[g_var] = func
ret_mod = IRModule(func_dict)
if mod.attrs is not None:
ret_mod = ret_mod.with_attrs(mod.attrs)
return ret_mod
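
A minimal usage sketch for the new pass (assumes a Relax IRModule `mod` that
defines a function named "create_flashinfer_paged_kv_cache"):

from mlc_chat.compiler_pass.prune_relax_func import PruneRelaxFunc

# Applying the pass with flashinfer disabled drops the FlashInfer KV-cache creator.
pruned_mod = PruneRelaxFunc(flashinfer=False)(mod)
assert all(
    gv.name_hint != "create_flashinfer_paged_kv_cache"
    for gv in pruned_mod.functions
)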
1 change: 1 addition & 0 deletions python/mlc_chat/interface/compile.py
@@ -164,6 +164,7 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int:
args,
pipeline=relax.get_pipeline( # type: ignore
"mlc_llm",
flashinfer=args.opt.flashinfer,
cublas_gemm=args.opt.cublas_gemm,
variable_bounds=variable_bounds,
additional_tirs=additional_tirs,
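
A hedged sketch of how the updated pipeline can be obtained and applied directly
(assumes `mod` is the compiled-model IRModule and the remaining pipeline keyword
arguments keep their defaults; the real compile flow passes the pipeline to the
build call as shown in the diff above):

from tvm import relax

pipeline = relax.get_pipeline("mlc_llm", flashinfer=False, cublas_gemm=False)
mod = pipeline(mod)  # phase 0 now prunes create_flashinfer_paged_kv_cache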
