Misc Cleanups of Compilation Pipeline (#1165)
junrushao authored Oct 31, 2023
1 parent b5bfa5b commit 8438b27
Showing 9 changed files with 68 additions and 34 deletions.
2 changes: 1 addition & 1 deletion python/mlc_chat/cli/compile.py
@@ -15,7 +15,7 @@
from ..support.auto_target import detect_target_and_host

logging.basicConfig(
level=logging.DEBUG,
level=logging.INFO,
style="{",
datefmt="%Y-%m-%d %H:%M:%S",
format="[{asctime}] {levelname} {filename}:{lineno}: {message}",
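
For context, a minimal stdlib-`logging` sketch of what the switch from `DEBUG` to `INFO` means for the CLI: records below INFO are filtered out, while the progress messages added in this commit still reach the console (the sample messages and file name below are illustrative):

```python
import logging

logging.basicConfig(
    level=logging.INFO,  # records below INFO (i.e. DEBUG) are dropped
    style="{",
    datefmt="%Y-%m-%d %H:%M:%S",
    format="[{asctime}] {levelname} {filename}:{lineno}: {message}",
)
logger = logging.getLogger(__name__)

logger.debug("verbose internals")                               # suppressed at INFO level
logger.info("Creating model from: %s", "mlc-chat-config.json")  # printed
```
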
8 changes: 8 additions & 0 deletions python/mlc_chat/compiler/compile.py
@@ -1,5 +1,6 @@
"""Python entrypoint of compilation."""
import dataclasses
import logging
from io import StringIO
from pathlib import Path
from typing import Callable
@@ -12,6 +13,8 @@
from .model import Model
from .quantization import Quantization

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class CompileArgs: # pylint: disable=too-many-instance-attributes
@@ -40,15 +43,20 @@ def _echo_args(args: CompileArgs) -> None:


def _compile(args: CompileArgs):
logger.info("Creating model from: %s", args.config)
model_config = args.model.config.from_file(args.config)
quantization = args.quantization
model, _ = args.model.quantize[quantization.kind](model_config, quantization)
logger.info("Exporting the model to TVM Unity compiler")
mod, _named_params = model.export_tvm(
spec=model.get_default_spec(), # type: ignore
)
logger.info("Running optimizations using TVM Unity")
with args.target:
mod = relax.get_pipeline("mlc_llm")(mod)
logger.info("Generating code using TVM Unity")
args.build_func(mod, args)
logger.info("Code dumped to: %s", args.output)


def compile( # pylint: disable=too-many-arguments,redefined-builtin
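
A rough usage sketch (not the project's exact code) of the flow these new log lines trace: export the quantized model to Relax, run the registered "mlc_llm" pipeline under the chosen target, then emit code. `model`, `target`, and `build_func` stand in for the objects `_compile` receives:

```python
import tvm
from tvm import relax


def run(model, target: tvm.target.Target, build_func) -> None:
    mod, _named_params = model.export_tvm(spec=model.get_default_spec())
    with target:  # the target context steers target-dependent lowering
        mod = relax.get_pipeline("mlc_llm")(mod)  # registered in compiler_pass/pipeline.py
    build_func(mod, target)  # code generation for the chosen target
```
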
3 changes: 1 addition & 2 deletions python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py
@@ -18,8 +18,7 @@ def transform_module(
_ctx: tvm.transform.PassContext,
) -> IRModule:
"""IRModule-level transformation"""
for g_var in list(mod.functions):
func = mod[g_var]
for g_var, func in mod.functions_items():
changed = False
for attr in self.attrs:
if func.attrs is not None and attr in func.attrs:
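
Several passes in this commit switch to `IRModule.functions_items()`, which yields `(GlobalVar, BaseFunc)` pairs directly instead of iterating GlobalVars and indexing back into the module. A small sketch of the pattern, assuming a TVM Unity build that provides `functions_items`:

```python
import tvm
from tvm import tir


def prim_func_names(mod: tvm.IRModule) -> list:
    """Collect the names of all PrimFuncs in a module."""
    names = []
    # previously: `for g_var in list(mod.functions): func = mod[g_var]`
    for g_var, func in mod.functions_items():
        if isinstance(func, tir.PrimFunc):
            names.append(g_var.name_hint)
    return names
```
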
23 changes: 13 additions & 10 deletions python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py
@@ -14,20 +14,23 @@ def transform_module(
_ctx: tvm.transform.PassContext,
) -> IRModule:
"""IRModule-level transformation"""
seq = []
for n_aux_tensor in [1, 2, 3, 4]:
for match_ewise in [0, 1, 2, 6]:
if match_ewise == 6 and n_aux_tensor != 4:
continue
mod = relax.transform.FuseOpsByPattern(
[
(
"decode_matmul",
*_pattern(match_ewise, n_aux_tensor),
)
]
)(mod)
mod = relax.transform.FuseTIR()(mod)
return mod
seq.append(
relax.transform.FuseOpsByPattern(
[
(
"decode_matmul",
*_pattern(match_ewise, n_aux_tensor),
)
]
)
)
seq.append(relax.transform.FuseTIR())
return tvm.transform.Sequential(seq)(mod)


def _pattern(match_ewise: int, n_aux_tensor: int):
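
The fusion passes above are now collected into a list and applied in one shot through `tvm.transform.Sequential` instead of rewriting `mod` inside the loop. A minimal sketch of the pattern, using stock Relax passes purely for illustration:

```python
import tvm
from tvm import relax


def apply_passes(mod: tvm.IRModule) -> tvm.IRModule:
    seq = []
    seq.append(relax.transform.FoldConstant())  # any IRModule-level pass can go here
    seq.append(relax.transform.FuseTIR())
    # Sequential composes the list into one pass and runs it under a single PassContext.
    return tvm.transform.Sequential(seq)(mod)
```

One practical upside is that the pass list is built once and can be inspected or reused before it is applied.
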
24 changes: 14 additions & 10 deletions python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py
@@ -20,18 +20,22 @@ def transform_module(
_ctx: tvm.transform.PassContext,
) -> IRModule:
"""IRModule-level transformation"""
seq = []
for n_aux_tensor in [2, 3]:
for match_tir_vars in [False, True]:
mod = relax.transform.FuseOpsByPattern(
[
(
"decode_take",
*_pattern(n_aux_tensor, match_tir_vars),
)
]
)(mod)
mod = relax.transform.FuseTIR()(mod)
for g_var, func in mod.functions.items():
seq.append(
relax.transform.FuseOpsByPattern(
[
(
"decode_take",
*_pattern(n_aux_tensor, match_tir_vars),
)
]
)
)
seq.append(relax.transform.FuseTIR())
mod = tvm.transform.Sequential(seq)(mod)
for g_var, func in mod.functions_items():
name = g_var.name_hint
if isinstance(func, tir.PrimFunc) and (("fused_decode" in name) and ("take" in name)):
mod = tvm.IRModule({"main": func})
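
The tail of this pass pulls each fused decode-take PrimFunc into a standalone module via `tvm.IRModule({"main": func})`. Purely as an illustration (not necessarily what the pass does next), such a one-function module can be scheduled on its own, e.g. with dlight under an assumed CUDA target:

```python
import tvm
from tvm import dlight as dl


def schedule_standalone(func: tvm.tir.PrimFunc) -> tvm.IRModule:
    scratch = tvm.IRModule({"main": func})  # wrap the single PrimFunc
    with tvm.target.Target("cuda"):         # assumed target for this sketch
        return dl.ApplyDefaultSchedule(dl.gpu.Fallback())(scratch)
```
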
@@ -31,7 +31,7 @@ def __init__(

def transform(self) -> IRModule:
"""Entry point"""
for g_var, func in self.mod.functions.items():
for g_var, func in self.mod.functions_items():
if isinstance(func, relax.Function):
updated_func = self.visit_expr(func)
updated_func = remove_all_unused(updated_func)
@@ -19,10 +19,8 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
),
]
)(mod)

transpose_matmul_codegen = _TransposeMatmulFuser(mod)
for g_var in mod.functions:
func = mod[g_var]
for g_var, func in mod.functions_items():
if isinstance(func, relax.Function):
func = transpose_matmul_codegen.visit_expr(func)
transpose_matmul_codegen.builder_.update_func(g_var, func)
@@ -32,20 +32,19 @@ def __init__(self, mod: IRModule):

def transform(self) -> IRModule:
"""Entry point of the transformation"""
for g_var, func in self.mod.functions.items():
for g_var, func in self.mod.functions_items():
if isinstance(func, tir.PrimFunc):
updated_func, tensor_sinfo_list = remove_global_buf_alloc(func)
if len(tensor_sinfo_list) > 0:
self.gv2new_tensor_sinfo[g_var] = (tensor_sinfo_list, func)
self.builder_.update_func(g_var, updated_func)

self.mod = self.builder_.get()
for g_var, func in self.mod.functions.items():
if not isinstance(func, relax.Function):
continue
updated_func = self.visit_expr(func)
updated_func = remove_all_unused(updated_func)
self.builder_.update_func(g_var, updated_func)
for g_var, func in self.mod.functions_items():
if isinstance(func, relax.Function):
updated_func = self.visit_expr(func)
updated_func = remove_all_unused(updated_func)
self.builder_.update_func(g_var, updated_func)
return self.builder_.get()

def visit_call_(self, call: relax.Call): # pylint: disable=arguments-renamed
23 changes: 23 additions & 0 deletions python/mlc_chat/compiler/compiler_pass/pipeline.py
@@ -1,5 +1,8 @@
"""The compilation pipeline for LLM applications."""
import logging

import tvm
from tvm import IRModule
from tvm import dlight as dl
from tvm.relax import register_pipeline # pylint: disable=no-name-in-module

@@ -10,6 +13,21 @@
from .fuse_transpose_matmul import FuseTransposeMatmul
from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc

logger = logging.getLogger(__name__)


@tvm.transform.module_pass(opt_level=0, name="_LogProgress")
class _LogProgress: # pylint: disable=too-few-public-methods
"""A dummy compiler pass that does nothing but logging."""

def __init__(self, *args):
self.args = args

def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
"""A dummy transformation"""
logger.info(*self.args)
return mod


@register_pipeline("mlc_llm")
def _mlc_llm_pipeline():
@@ -18,27 +36,32 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
seq = tvm.transform.Sequential(
[
# Phase 1. Passes on high-level operator graph
_LogProgress("Running TVM Relax graph-level optimizations"),
FuseDecodeTranspose(skip_gemm=False),
FuseTransposeMatmul(),
# Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
_LogProgress("Lowering to TVM TIR kernels"),
tvm.relax.transform.LegalizeOps(),
tvm.relax.transform.AnnotateTIROpPattern(),
tvm.relax.transform.FoldConstant(),
tvm.relax.transform.FuseOps(),
tvm.relax.transform.FuseTIR(),
# Phase 3. Passes on TIR
_LogProgress("Running TVM TIR-level optimizations"),
FuseDecodeMatmulEwise(),
FuseDecodeTake(),
tvm.relax.transform.DeadCodeElimination(),
CleanUpTIRAttrs(["op_pattern"]),
# Phase 4. Low-level Optimizations
_LogProgress("Running TVM Dlight low-level optimizations"),
dl.ApplyDefaultSchedule(
dl.gpu.Matmul(),
dl.gpu.GEMV(),
dl.gpu.Reduction(),
dl.gpu.GeneralReduction(),
dl.gpu.Fallback(),
),
_LogProgress("Running memory optimizations"),
LiftTIRGlobalBufferAlloc(),
tvm.tir.transform.ForceNarrowIndexToInt32(),
]
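
`_LogProgress` follows TVM's class-style module-pass convention: decorate the class with `tvm.transform.module_pass` and implement `transform_module`. A minimal sketch of another no-op pass in the same style (the name and message are made up):

```python
import logging

import tvm
from tvm import IRModule

logger = logging.getLogger(__name__)


@tvm.transform.module_pass(opt_level=0, name="_CountFunctions")
class _CountFunctions:  # pylint: disable=too-few-public-methods
    """Logs how many functions the module holds, then returns it unchanged."""

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """A dummy transformation that only logs."""
        logger.info("Module currently holds %d functions", len(mod.functions))
        return mod
```

With `opt_level=0` the pass is never skipped for being above the configured optimization level, so the progress messages show up whenever INFO logging is enabled.
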
