From 04587f70ed3879b7f5a40fdab35630f86ccfe514 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 16 Feb 2024 20:36:38 -0500 Subject: [PATCH] [Pass] Scatter centralized TupleGetItems This PR introduces a pass that scatters the TupleGetItems of packed parameters, which are otherwise concentrated at the beginning of Relax functions. Performing hundreds of TupleGetItems back-to-back at runtime causes significant CPU delay; once these operations are scattered, that delay can be hidden because they run in parallel with GPU kernels. --- python/mlc_chat/compiler_pass/pipeline.py | 2 + .../compiler_pass/scatter_tuple_get_item.py | 51 +++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 python/mlc_chat/compiler_pass/scatter_tuple_get_item.py diff --git a/python/mlc_chat/compiler_pass/pipeline.py b/python/mlc_chat/compiler_pass/pipeline.py index 02a44eb1be..6a5420afa1 100644 --- a/python/mlc_chat/compiler_pass/pipeline.py +++ b/python/mlc_chat/compiler_pass/pipeline.py @@ -27,6 +27,7 @@ from .fuse_transpose_matmul import FuseTransposeMatmul from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc from .rewrite_kv_cache_creation import RewriteKVCacheCreation +from .scatter_tuple_get_item import ScatterTupleGetItem logger = logging.getLogger(__name__) @@ -132,6 +133,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I if target.kind.name != "cuda" else tvm.transform.Sequential([]) ), + ScatterTupleGetItem(), tvm.relax.transform.RewriteDataflowReshape(), tvm.relax.transform.ToNonDataflow(), tvm.relax.transform.RemovePurityChecking(), diff --git a/python/mlc_chat/compiler_pass/scatter_tuple_get_item.py b/python/mlc_chat/compiler_pass/scatter_tuple_get_item.py new file mode 100644 index 0000000000..281c6ec9a3 --- /dev/null +++ b/python/mlc_chat/compiler_pass/scatter_tuple_get_item.py @@ -0,0 +1,51 @@ +"""A compiler pass that scatters TupleGetItem for lazy TupleGetItems.""" + +from typing import Dict + +import tvm +from tvm
import relax
+from tvm.ir.module import IRModule
+from tvm.relax.analysis import remove_all_unused
+from tvm.relax.expr import Expr, Var
+from tvm.relax.expr_functor import PyExprMutator, mutator
+
+
+@tvm.transform.module_pass(opt_level=0, name="ScatterTupleGetItem")
+class ScatterTupleGetItem:  # pylint: disable=too-few-public-methods
+    """Module pass that re-emits TupleGetItem bindings at the first use of their DataflowVar."""
+
+    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
+        """Apply the scatter mutator to `mod` and return the resulting IRModule."""
+        return _Scatter(mod).transform()
+
+
+@mutator
+class _Scatter(PyExprMutator):  # pylint: disable=abstract-method
+    def __init__(self, mod: IRModule) -> None:
+        super().__init__(mod)
+        self.mod = mod
+        self.var_map: Dict[Var, Expr] = {}  # bound var -> TupleGetItem awaiting re-emission
+
+    def transform(self) -> IRModule:
+        """Visit each relax.Function, prune unused bindings, and return the updated module."""
+        for g_var, func in self.mod.functions_items():
+            if isinstance(func, relax.Function):
+                updated_func = self.visit_expr(func)
+                updated_func = remove_all_unused(updated_func)  # drops superseded originals
+                self.builder_.update_func(g_var, updated_func)
+        return self.builder_.get()
+
+    def visit_var_binding_(self, binding: relax.VarBinding):
+        super().visit_var_binding_(binding)  # default handling of the binding runs first
+        if isinstance(binding.value, relax.TupleGetItem):
+            self.var_map[binding.var] = binding.value  # candidate to re-emit at first use
+
+    def visit_dataflow_var_(  # pylint: disable=arguments-renamed
+        self, var: relax.DataflowVar
+    ) -> Expr:
+        if var in self.var_map:  # first use: re-emit the TupleGetItem here
+            new_var = self.builder_.emit(self.var_map[var], name_hint=var.name_hint)
+            self.set_var_remap(var.vid, new_var)  # later uses must resolve to new_var
+            self.var_map.pop(var)
+            return new_var
+        return super().visit_dataflow_var_(var)  # default visit applies recorded remaps