From 9d35c77e91b522405bcd1fe7e7c4738bc20f0159 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sun, 7 Jan 2024 12:00:05 -0500
Subject: [PATCH] [Pass] PruneRelaxFunc to remove Relax function based on target

We recently noticed that when FlashInfer is not built (due to an
unsupported CUDA architecture or platform), running the single-sequence
ChatModule hits a VM function initialization error. The missing function
is referenced only by `create_flashinfer_paged_kv_cache`, which is never
actually invoked in the single-sequence flow. The error occurs because
the Relax VM eagerly initializes every referenced PackedFunc at VM
initialization time rather than loading them lazily, so the FlashInfer
PackedFuncs are looked up even though `create_flashinfer_paged_kv_cache`
is never called. As a result, the failure shows up whenever FlashInfer
is unavailable.

This PR adds a compiler pass that removes `create_flashinfer_paged_kv_cache`
(and other similar functions that may be introduced in the future) based on
the target. This pass effectively addresses the issue.
---
 python/mlc_chat/compiler_pass/pipeline.py |  5 ++-
 .../compiler_pass/prune_relax_func.py     | 31 +++++++++++++++++++
 python/mlc_chat/interface/compile.py      |  1 +
 3 files changed, 36 insertions(+), 1 deletion(-)
 create mode 100644 python/mlc_chat/compiler_pass/prune_relax_func.py

diff --git a/python/mlc_chat/compiler_pass/pipeline.py b/python/mlc_chat/compiler_pass/pipeline.py
index 81a79b0a21..83a69757ae 100644
--- a/python/mlc_chat/compiler_pass/pipeline.py
+++ b/python/mlc_chat/compiler_pass/pipeline.py
@@ -20,6 +20,7 @@
 from .fuse_ft_dequantize_matmul_epilogue import FuseFTDequantizeEpilogue
 from .fuse_transpose_matmul import FuseTransposeMatmul
 from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc
+from .prune_relax_func import PruneRelaxFunc
 
 logger = logging.getLogger(__name__)
 
@@ -59,6 +60,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
 
 @register_pipeline("mlc_llm")
 def _mlc_llm_pipeline(  # pylint: disable=too-many-arguments
+    flashinfer: bool = False,
     cublas_gemm: bool = False,
     variable_bounds: Dict[str, int] = None,
     additional_tirs: Dict[str, tvm.tir.PrimFunc] = None,
@@ -75,7 +77,8 @@ def _mlc_llm_pipeline(  # pylint: disable=too-many-arguments
     def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
         seq = tvm.transform.Sequential(
             [
-                # Phase 0. Add additional information for compilation
+                # Phase 0. Add additional information for compilation and remove unused Relax func
+                PruneRelaxFunc(flashinfer=flashinfer),
                 AttachVariableBounds(variable_bounds),
                 AttachAdditionalPrimFuncs(additional_tirs),
                 _DebugDump("debug-phase0.py", debug_dump, show_meta=False),
diff --git a/python/mlc_chat/compiler_pass/prune_relax_func.py b/python/mlc_chat/compiler_pass/prune_relax_func.py
new file mode 100644
index 0000000000..5271140786
--- /dev/null
+++ b/python/mlc_chat/compiler_pass/prune_relax_func.py
@@ -0,0 +1,31 @@
+"""A pass that removes undesired Relax function from IRModule based on target."""
+import tvm
+from tvm import IRModule
+
+
+@tvm.transform.module_pass(opt_level=0, name="PruneRelaxFunc")
+class PruneRelaxFunc:  # pylint: disable=too-few-public-methods
+    """Removes undesired Relax function from IRModule based on target."""
+
+    def __init__(self, flashinfer: bool) -> None:
+        """Initializer.
+
+        Parameters
+        ----------
+        flashinfer : bool
+            A boolean indicating if flashinfer is enabled.
+        """
+        self.flashinfer = flashinfer
+
+    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
+        """Entrypoint"""
+        func_dict = {}
+        for g_var, func in mod.functions_items():
+            # Remove "create_flashinfer_paged_kv_cache" for unsupported target
+            if g_var.name_hint == "create_flashinfer_paged_kv_cache" and not self.flashinfer:
+                continue
+            func_dict[g_var] = func
+        ret_mod = IRModule(func_dict)
+        if mod.attrs is not None:
+            ret_mod = ret_mod.with_attrs(mod.attrs)
+        return ret_mod
diff --git a/python/mlc_chat/interface/compile.py b/python/mlc_chat/interface/compile.py
index a483a055e7..9627a09ca2 100644
--- a/python/mlc_chat/interface/compile.py
+++ b/python/mlc_chat/interface/compile.py
@@ -166,6 +166,7 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int:
         args,
         pipeline=relax.get_pipeline(  # type: ignore
             "mlc_llm",
+            flashinfer=args.opt.flashinfer,
             cublas_gemm=args.opt.cublas_gemm,
             variable_bounds=variable_bounds,
             additional_tirs=additional_tirs,
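
The sketch below (not part of the patch) shows one way the new pass could be exercised in isolation. It is only an illustration under assumptions: it presumes a TVM build with the Relax TVMScript frontend, and the toy module, the `prefill` function name, and the tensor shapes are placeholders chosen to keep the example self-contained; only `create_flashinfer_paged_kv_cache` and the `PruneRelaxFunc` import path come from the patch itself.

# Illustrative sketch (not part of the patch): exercising PruneRelaxFunc on a toy
# IRModule. Assumes TVM with Relax and the TVMScript parser; shapes are arbitrary.
import tvm
from tvm.script import ir as I, relax as R

from mlc_chat.compiler_pass.prune_relax_func import PruneRelaxFunc


@I.ir_module
class ToyModule:
    @R.function
    def create_flashinfer_paged_kv_cache(x: R.Tensor((4, 4), "float16")) -> R.Tensor((4, 4), "float16"):
        # Stand-in body; the real function constructs a FlashInfer-backed KV cache.
        return x

    @R.function
    def prefill(x: R.Tensor((4, 4), "float16")) -> R.Tensor((4, 4), "float16"):
        # Stand-in for a function that every flow keeps.
        return x


# With flashinfer disabled, the FlashInfer-specific function is dropped ...
pruned = PruneRelaxFunc(flashinfer=False)(ToyModule)
assert "create_flashinfer_paged_kv_cache" not in {
    g_var.name_hint for g_var in pruned.get_global_vars()
}

# ... and with flashinfer enabled, the module is left untouched.
kept = PruneRelaxFunc(flashinfer=True)(ToyModule)
assert "create_flashinfer_paged_kv_cache" in {
    g_var.name_hint for g_var in kept.get_global_vars()
}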