diff --git a/python/mlc_chat/compiler_pass/pipeline.py b/python/mlc_chat/compiler_pass/pipeline.py index 81a79b0a21..83a69757ae 100644 --- a/python/mlc_chat/compiler_pass/pipeline.py +++ b/python/mlc_chat/compiler_pass/pipeline.py @@ -20,6 +20,7 @@ from .fuse_ft_dequantize_matmul_epilogue import FuseFTDequantizeEpilogue from .fuse_transpose_matmul import FuseTransposeMatmul from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc +from .prune_relax_func import PruneRelaxFunc logger = logging.getLogger(__name__) @@ -59,6 +60,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR @register_pipeline("mlc_llm") def _mlc_llm_pipeline( # pylint: disable=too-many-arguments + flashinfer: bool = False, cublas_gemm: bool = False, variable_bounds: Dict[str, int] = None, additional_tirs: Dict[str, tvm.tir.PrimFunc] = None, @@ -75,7 +77,8 @@ def _mlc_llm_pipeline( # pylint: disable=too-many-arguments def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: seq = tvm.transform.Sequential( [ - # Phase 0. Add additional information for compilation + # Phase 0. Add additional information for compilation and remove unused Relax func + PruneRelaxFunc(flashinfer=flashinfer), AttachVariableBounds(variable_bounds), AttachAdditionalPrimFuncs(additional_tirs), _DebugDump("debug-phase0.py", debug_dump, show_meta=False), diff --git a/python/mlc_chat/compiler_pass/prune_relax_func.py b/python/mlc_chat/compiler_pass/prune_relax_func.py new file mode 100644 index 0000000000..5271140786 --- /dev/null +++ b/python/mlc_chat/compiler_pass/prune_relax_func.py @@ -0,0 +1,31 @@ +"""A pass that removes undesired Relax function from IRModule based on target.""" +import tvm +from tvm import IRModule + + +@tvm.transform.module_pass(opt_level=0, name="PruneRelaxFunc") +class PruneRelaxFunc: # pylint: disable=too-few-public-methods + """Removes undesired Relax function from IRModule based on target.""" + + def __init__(self, flashinfer: bool) -> None: + """Initializer. + + Parameters + ---------- + flashinfer : bool + A boolean indicating if flashinfer is enabled. + """ + self.flashinfer = flashinfer + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + func_dict = {} + for g_var, func in mod.functions_items(): + # Remove "create_flashinfer_paged_kv_cache" for unsupported target + if g_var.name_hint == "create_flashinfer_paged_kv_cache" and not self.flashinfer: + continue + func_dict[g_var] = func + ret_mod = IRModule(func_dict) + if mod.attrs is not None: + ret_mod = ret_mod.with_attrs(mod.attrs) + return ret_mod diff --git a/python/mlc_chat/interface/compile.py b/python/mlc_chat/interface/compile.py index a483a055e7..9627a09ca2 100644 --- a/python/mlc_chat/interface/compile.py +++ b/python/mlc_chat/interface/compile.py @@ -166,6 +166,7 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: args, pipeline=relax.get_pipeline( # type: ignore "mlc_llm", + flashinfer=args.opt.flashinfer, cublas_gemm=args.opt.cublas_gemm, variable_bounds=variable_bounds, additional_tirs=additional_tirs,