diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py
new file mode 100644
index 0000000000..66e3dc17d6
--- /dev/null
+++ b/python/mlc_chat/cli/compile.py
@@ -0,0 +1,141 @@
+"""Command line entrypoint of compilation."""
+import argparse
+import json
+import logging
+from pathlib import Path
+from typing import Union
+
+from mlc_chat.compiler.compile import compile  # pylint: disable=redefined-builtin
+from mlc_chat.compiler.model import MODELS
+
+from ..support.auto_config import detect_config
+from ..support.auto_target import detect_target_and_host
+
+logging.basicConfig(
+    level=logging.DEBUG,
+    style="{",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    format="[{asctime}] {levelname} {filename}:{lineno}: {message}",
+)
+
+
+def _parse_config(path: Union[str, Path]) -> Path:
+    try:
+        return detect_config(Path(path))
+    except ValueError as err:
+        raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}")
+
+
+def _parse_output(path: Union[str, Path]) -> Path:
+    path = Path(path)
+    parent = path.parent
+    if not parent.is_dir():
+        raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}")
+    return path
+
+
+def _parse_model_type(model_type: str, config: Path) -> str:
+    if model_type == "auto":
+        with open(config, "r", encoding="utf-8") as config_file:
+            cfg = json.load(config_file)
+        if "model_type" not in cfg:
+            raise ValueError(
+                f"'model_type' not found in: {config}. "
+                f"Please explicitly specify `--model-type` instead"
+            )
+        model_type = cfg["model_type"]
+    if model_type not in MODELS:
+        raise ValueError(f"Unknown model type: {model_type}. Available ones: {list(MODELS.keys())}")
+    return model_type
+
+
+def main():
+    """Parse command line arguments and call `mlc_chat.compiler.compile`."""
+    parser = argparse.ArgumentParser("MLC LLM Compiler")
+    parser.add_argument(
+        "--config",
+        type=_parse_config,
+        required=True,
+        help="Path to config.json file or to the directory that contains config.json, which is "
+        "a HuggingFace standard that defines model architecture, for example, "
+        "https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json",
+    )
+    parser.add_argument(
+        "--quantization",
+        type=str,
+        required=True,
+        choices=[
+            "q0f16",
+            "q0f32",
+            "q3f16_1",
+            "q3f32_1",
+            "q4f16_1",
+            "q4f16_ft",
+            "q4f32_1",
+        ],
+        help="The quantization format. TBD",
+    )
+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="auto",
+        choices=["auto"] + list(MODELS.keys()),
+        help="Model architecture, for example, llama. If not set, it is inferred "
+        "from the config.json file.",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="auto",
+        help="The GPU device to compile the model to. If not set, it is inferred from locally "
+        "available GPUs.",
+    )
+    parser.add_argument(
+        "--host",
+        type=str,
+        default="auto",
+        choices=[
+            "auto",
+            "arm",
+            "arm64",
+            "aarch64",
+            "x86-64",
+        ],
+        help="The host CPU ISA to compile the model to. If not set, it is inferred from the "
+        "local CPU.",
+    )
+    parser.add_argument(
+        "--opt",
+        type=str,
+        default="",
+        help="Optimization flags.",
+    )
+    parser.add_argument(
+        "--output",
+        "-o",
+        type=_parse_output,
+        required=True,
+        help="The name of the output file. The suffix determines if the output file is a "
+        "shared library or a static library. Available suffixes: "
+        "1) Linux: .so (shared), .a (static); "
+        "2) macOS: .dylib (shared), .a (static); "
+        "3) Windows: .dll (shared), .lib (static); "
+        "4) Android, iOS: .tar (static); "
+        "5) Web: .wasm (web assembly)",
+    )
+    parsed = parser.parse_args()
+    target, build_func = detect_target_and_host(parsed.device, parsed.host)
+    parsed.model_type = _parse_model_type(parsed.model_type, parsed.config)
+    compile(
+        config=parsed.config,
+        quantization=parsed.quantization,
+        model_type=parsed.model_type,
+        target=target,
+        opt=parsed.opt,
+        build_func=build_func,
+        output=parsed.output,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py
new file mode 100644
index 0000000000..387ae7e14c
--- /dev/null
+++ b/python/mlc_chat/compiler/compile.py
@@ -0,0 +1,51 @@
+"""Python entrypoint of compilation."""
+import dataclasses
+import logging
+from io import StringIO
+from pathlib import Path
+from typing import Callable
+
+from tvm import IRModule
+from tvm.target import Target
+
+from ..support.style import bold
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class CompileArgs:
+    """Arguments to MLC LLM's compiler."""
+
+    config: Path
+    quantization: str
+    model_type: str
+    target: Target
+    opt: str
+    build_func: Callable[[IRModule, "CompileArgs"], None]
+    output: Path
+
+
+def _echo_args(args: CompileArgs) -> None:
+    out = StringIO()
+    print(f"{bold('Compiling with arguments:')}", file=out)
+    print(f"  {bold('--config'):<25} {args.config}", file=out)
+    print(f"  {bold('--quantization'):<25} {args.quantization}", file=out)
+    print(f"  {bold('--model-type'):<25} {args.model_type}", file=out)
+    print(f"  {bold('--target'):<25} {args.target.export()}", file=out)
+    print(f"  {bold('--opt'):<25} {args.opt}", file=out)
+    print(f"  {bold('--output'):<25} {args.output}", file=out)
+    print(out.getvalue().rstrip())
+
+
+def compile(  # pylint: disable=too-many-arguments,redefined-builtin
+    config: Path,
+    quantization,
+    model_type: str,
+    target: Target,
+    opt,
+    build_func: Callable[[IRModule, CompileArgs], None],
+    output: Path,
+):
+    args = CompileArgs(config, quantization, model_type, target, opt, build_func, output)
+    _echo_args(args)
diff --git a/python/mlc_chat/compiler/model/__init__.py b/python/mlc_chat/compiler/model/__init__.py
index b568bd84f7..8bb4879e7d 100644
--- a/python/mlc_chat/compiler/model/__init__.py
+++ b/python/mlc_chat/compiler/model/__init__.py
@@ -1,2 +1,2 @@
 """Model definition for the compiler."""
-from . import llama, llama_config, llama_parameter
+from .model import MODELS, Model
diff --git a/python/mlc_chat/compiler/model/llama.py b/python/mlc_chat/compiler/model/llama_model.py
similarity index 98%
rename from python/mlc_chat/compiler/model/llama.py
rename to python/mlc_chat/compiler/model/llama_model.py
index 663e6d93c2..49e947f741 100644
--- a/python/mlc_chat/compiler/model/llama.py
+++ b/python/mlc_chat/compiler/model/llama_model.py
@@ -156,11 +156,11 @@ def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor
 
 
 class LlamaForCasualLM(nn.Module):
-    def __init__(self, config: LlamaConfig, dtype: str = "float32"):
+    def __init__(self, config: LlamaConfig):
         self.model = LlamaModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.vocab_size = config.vocab_size
-        self.dtype = dtype
+        self.dtype = "float32"
 
     def to(self, dtype: Optional[str] = None):
         super().to(dtype=dtype)
diff --git a/python/mlc_chat/compiler/model/llama_parameter.py b/python/mlc_chat/compiler/model/llama_parameter.py
index b0fa867130..4c68fdc899 100644
--- a/python/mlc_chat/compiler/model/llama_parameter.py
+++ b/python/mlc_chat/compiler/model/llama_parameter.py
@@ -7,10 +7,11 @@
 import numpy as np
 
 from ..parameter import ExternMapping
-from .llama import LlamaConfig, LlamaForCasualLM
+from .llama_config import LlamaConfig
+from .llama_model import LlamaForCasualLM
 
 
-def hf_torch(model_config: LlamaConfig) -> ExternMapping:
+def huggingface(model_config: LlamaConfig, _) -> ExternMapping:
     """Returns a parameter mapping that maps from the names of MLC LLM parameters to
     the names of HuggingFace PyTorch parameters.
 
diff --git a/python/mlc_chat/compiler/model/model.py b/python/mlc_chat/compiler/model/model.py
new file mode 100644
index 0000000000..6f36de2ba2
--- /dev/null
+++ b/python/mlc_chat/compiler/model/model.py
@@ -0,0 +1,37 @@
+"""A centralized registry of all existing model architectures and their configurations."""
+import dataclasses
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional
+
+from tvm.relax.frontend import nn
+
+from ..parameter import ExternMapping, QuantizeMapping
+from . import llama_config, llama_model, llama_parameter
+
+ModelConfig = Any
+QuantizeConfig = Any
+
+LoaderType = Callable[[ModelConfig, QuantizeConfig], ExternMapping]
+QuantizerType = Callable[[ModelConfig, QuantizeConfig], QuantizeMapping]
+
+
+@dataclasses.dataclass
+class Model:
+    """All about a model architecture: its configuration, its parameter loader and quantization."""
+
+    model: Callable[[ModelConfig], nn.Module]
+    config: Callable[[Path], ModelConfig]
+    source_loader_huggingface: Optional[LoaderType] = None
+    source_loader_awq: Optional[LoaderType] = None
+    quantizer_group_quant: Optional[QuantizerType] = None
+
+
+MODELS: Dict[str, Model] = {
+    "llama": Model(
+        model=llama_model.LlamaForCasualLM,
+        config=llama_config.LlamaConfig.from_file,
+        source_loader_huggingface=llama_parameter.huggingface,
+        source_loader_awq=None,
+        quantizer_group_quant=None,
+    )
+}
diff --git a/python/mlc_chat/support/auto_config.py b/python/mlc_chat/support/auto_config.py
index 1a4d9bf765..a31515f41e 100644
--- a/python/mlc_chat/support/auto_config.py
+++ b/python/mlc_chat/support/auto_config.py
@@ -2,8 +2,12 @@
 import logging
 from pathlib import Path
 
+from .style import green, red
+
 logger = logging.getLogger(__name__)
 
+FOUND = green("Found")
+
 
 def detect_config(config_path: Path) -> Path:
     """Detect and return the path that points to config.json. If config_path is a directory,
@@ -30,5 +34,5 @@ def detect_config(config_path: Path) -> Path:
     else:
         config_json_path = config_path
 
-    logger.info("Found config.json: %s", config_json_path)
+    logger.info("%s model configuration: %s", FOUND, config_json_path)
     return config_json_path
diff --git a/python/mlc_chat/support/auto_target.py b/python/mlc_chat/support/auto_target.py
new file mode 100644
index 0000000000..200596b33e
--- /dev/null
+++ b/python/mlc_chat/support/auto_target.py
@@ -0,0 +1,282 @@
+"""Helper functions for target auto-detection."""
+import logging
+from typing import TYPE_CHECKING, Callable, Optional, Tuple
+
+from tvm import IRModule, relax
+from tvm._ffi import register_func
+from tvm.contrib import tar, xcode
+from tvm.target import Target
+
+from .style import green, red
+
+if TYPE_CHECKING:
+    from mlc_chat.compiler.compile import CompileArgs
+
+
+logger = logging.getLogger(__name__)
+
+# TODO: add help message on how to specify the target manually # pylint: disable=fixme
+# TODO: revisit system_lib_prefix handling # pylint: disable=fixme
+# TODO: include host detection logic below after the new TVM build is done. # pylint: disable=fixme
+HELP_MSG = """TBD"""
+FOUND = green("Found")
+NOT_FOUND = red("Not found")
+BuildFunc = Callable[[IRModule, "CompileArgs"], None]
+
+
+def detect_target_and_host(target_hint: str, host_hint: str) -> Tuple[Target, BuildFunc]:
+    target, build_func = detect_target_gpu(target_hint)
+    if target.host is None:
+        target = Target(target, host=detect_target_host(host_hint))
+    return target, build_func
+
+
+def detect_target_gpu(hint: str) -> Tuple[Target, BuildFunc]:
+    if hint in ["iphone", "android", "webgpu", "mali", "opencl"]:
+        hint += ":generic"
+    if hint == "auto":
+        logger.info("Detecting potential target devices: %s", ", ".join(AUTO_DETECT_DEVICES))
+        target: Optional[Target] = None
+        for device in AUTO_DETECT_DEVICES:
+            device_target = _detect_target_from_device(device + ":0")
+            if device_target is not None and target is None:
+                target = device_target
+        if target is None:
+            raise ValueError("No GPU target detected. Please specify explicitly")
Please specify explicitly") + return target, _build_default() + if hint in AUTO_DETECT_DEVICES: + target = _detect_target_from_device(hint + ":0") + if target is None: + raise ValueError(f"No GPU target detected from device: {hint}") + return target, _build_default() + if hint in PRESET: + preset = PRESET[hint] + target = Target(preset["target"]) # type: ignore[index] + build = preset.get("build", _build_default) # type: ignore[attr-defined] + return target, build() + if _is_device(hint): + logger.info("Detecting target device: %s", hint) + target = Target.from_device(hint) + logger.info("%s target: %s", FOUND, target.export()) + return target, _build_default() + try: + logger.info("Try creating device target from string: %s", hint) + target = Target(hint) + logger.info("%s target: %s", FOUND, target.export()) + return target, _build_default() + except Exception as err: + logger.info("%s: Failed to create target", NOT_FOUND) + raise ValueError(f"Invalid target: {hint}") from err + + +def detect_target_host(hint: str) -> Target: + """Detect the host CPU architecture.""" + # cpu = codegen.llvm_get_system_cpu() + # triple = codegen.llvm_get_system_triple() + # vendor = codegen.llvm_get_system_x86_vendor() + if hint == "auto": + hint = "x86-64" + if hint == "x86-64": + hint = "x86_64" + return Target({"kind": "llvm", "mtriple": f"{hint}-unknown-unknown"}) + + +def _is_device(device: str): + if " " in device: + return False + if device.count(":") != 1: + return False + return True + + +def _detect_target_from_device(device: str) -> Optional[Target]: + try: + target = Target.from_device(device) + except ValueError: + logger.info("%s: target device: %s", NOT_FOUND, device) + return None + logger.info( + '%s configuration of target device "%s": %s', + FOUND, + device, + target.export(), + ) + return target + + +def _build_metal_x86_64(): + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + assert output.suffix == ".dylib" + relax.build( + mod, + target=args.target, + ).export_library( + str(output), + fcompile=xcode.create_dylib, + sdk="macosx", + arch="x86_64", + ) + + return build + + +def _build_iphone(): + @register_func("tvm_callback_metal_compile", override=True) + def compile_metal(src, target): + if target.libs: + return xcode.compile_metal(src, sdk=target.libs[0]) + return xcode.compile_metal(src) + + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + system_lib_prefix = f"{args.model_type}_{args.quantization}_".replace("-", "_") + assert output.suffix == ".tar" + relax.build( + mod.with_attr("system_lib_prefix", system_lib_prefix), + target=args.target, + system_lib=True, + ).export_library( + str(output), + fcompile=tar.tar, + ) + + return build + + +def _build_android(): + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + system_lib_prefix = f"{args.model_type}_{args.quantization}_".replace("-", "_") + assert output.suffix == ".tar" + relax.build( + mod.with_attr("system_lib_prefix", system_lib_prefix), + target=args.target, + system_lib=True, + ).export_library( + str(output), + fcompile=tar.tar, + ) + + return build + + +def _build_webgpu(): + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + assert output.suffix == ".wasm" + relax.build( + mod, + target=args.target, + system_lib=True, + ).export_library( + str(output), + ) + + return build + + +def _build_default(): + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + if output.suffix in [".a", ".lib"]: + system_lib = True + 
+        elif output.suffix in [".so", ".dylib", ".dll"]:
+            system_lib = False
+        else:
+            logger.warning("Unknown output suffix: %s. Assuming shared library.", output.suffix)
+            system_lib = False
+        relax.build(
+            mod,
+            target=args.target,
+            system_lib=system_lib,
+        ).export_library(
+            str(output),
+        )
+
+    return build
+
+
+AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan"]
+
+PRESET = {
+    "iphone:generic": {
+        "target": {
+            "kind": "metal",
+            "max_threads_per_block": 256,
+            "max_shared_memory_per_block": 32768,
+            "thread_warp_size": 1,
+            "libs": ["iphoneos"],
+            "host": {
+                "kind": "llvm",
+                "mtriple": "arm64-apple-darwin",
+            },
+        },
+        "build": _build_iphone,
+    },
+    "android:generic": {
+        "target": {
+            "kind": "opencl",
+            "host": {
+                "kind": "llvm",
+                "mtriple": "aarch64-linux-android",
+            },
+        },
+        "build": _build_android,
+    },
+    "metal:x86-64": {
+        "target": {
+            "kind": "metal",
+            "max_threads_per_block": 256,
+            "max_shared_memory_per_block": 32768,
+            "thread_warp_size": 1,
+        },
+        "build": _build_metal_x86_64,
+    },
+    "webgpu:generic": {
+        "target": {
+            "kind": "webgpu",
+            "host": {
+                "kind": "llvm",
+                "mtriple": "wasm32-unknown-unknown-wasm",
+            },
+        },
+        "build": _build_webgpu,
+    },
+    "opencl:generic": {
+        "target": {
+            "kind": "opencl",
+        },
+    },
+    "mali:generic": {
+        "target": {
+            "kind": "opencl",
+            "host": {
+                "kind": "llvm",
+                "mtriple": "aarch64-linux-gnu",
+            },
+        },
+    },
+    "metal:generic": {
+        "target": {
+            "kind": "metal",
+            "max_threads_per_block": 256,
+            "max_shared_memory_per_block": 32768,
+            "thread_warp_size": 1,
+        },
+    },
+    "vulkan:generic": {
+        "target": {
+            "kind": "vulkan",
+            "max_threads_per_block": 256,
+            "max_shared_memory_per_block": 32768,
+            "thread_warp_size": 1,
+            "supports_float16": 1,
+            "supports_int16": 1,
+            "supports_int8": 1,
+            "supports_8bit_buffer": 1,
+            "supports_16bit_buffer": 1,
+            "supports_storage_buffer_storage_class": 1,
+        },
+    },
+}
diff --git a/python/mlc_chat/support/auto_weight.py b/python/mlc_chat/support/auto_weight.py
index 74e8a8b8c0..b19ec6b07a 100644
--- a/python/mlc_chat/support/auto_weight.py
+++ b/python/mlc_chat/support/auto_weight.py
@@ -8,7 +8,9 @@
 
 
 def detect_weight(
-    weight_path: Path, config_json_path: Path, weight_format: str = "auto"
+    weight_path: Path,
+    config_json_path: Path,
+    weight_format: str = "auto",
 ) -> Tuple[Path, str]:
     """Detect the weight directory, and detect the weight format.
 
diff --git a/python/mlc_chat/support/style.py b/python/mlc_chat/support/style.py
new file mode 100644
index 0000000000..5b2272e1a0
--- /dev/null
+++ b/python/mlc_chat/support/style.py
@@ -0,0 +1,62 @@
+"""Printing styles."""
+
+from enum import Enum
+
+
+class Styles(Enum):
+    """Predefined set of styles to be used.
+
+    Reference:
+    - https://en.wikipedia.org/wiki/ANSI_escape_code#3-bit_and_4-bit
+    - https://stackoverflow.com/a/17303428
+    """
+
+    RED = "\033[91m"
+    GREEN = "\033[92m"
+    YELLOW = "\033[93m"
+    BLUE = "\033[94m"
+    PURPLE = "\033[95m"
+    CYAN = "\033[96m"
+    BOLD = "\033[1m"
+    UNDERLINE = "\033[4m"
+    END = "\033[0m"
+
+
+def red(text: str) -> str:
+    """Return red text."""
+    return f"{Styles.RED.value}{text}{Styles.END.value}"
+
+
+def green(text: str) -> str:
+    """Return green text."""
+    return f"{Styles.GREEN.value}{text}{Styles.END.value}"
+
+
+def yellow(text: str) -> str:
+    """Return yellow text."""
+    return f"{Styles.YELLOW.value}{text}{Styles.END.value}"
+
+
+def blue(text: str) -> str:
+    """Return blue text."""
+    return f"{Styles.BLUE.value}{text}{Styles.END.value}"
+
+
+def purple(text: str) -> str:
+    """Return purple text."""
+    return f"{Styles.PURPLE.value}{text}{Styles.END.value}"
+
+
+def cyan(text: str) -> str:
+    """Return cyan text."""
+    return f"{Styles.CYAN.value}{text}{Styles.END.value}"
+
+
+def bold(text: str) -> str:
+    """Return bold text."""
+    return f"{Styles.BOLD.value}{text}{Styles.END.value}"
+
+
+def underline(text: str) -> str:
+    """Return underlined text."""
+    return f"{Styles.UNDERLINE.value}{text}{Styles.END.value}"
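
The new CLI is a thin wrapper over `mlc_chat.compiler.compile.compile`. The snippet below is a minimal, hypothetical sketch of the equivalent programmatic path; the model directory, quantization choice, and output name are placeholders, and note that in this diff `compile` only assembles and echoes `CompileArgs` (`build_func` is passed through but not yet invoked).

    # Hypothetical usage sketch; the model directory, quantization choice, and
    # output name below are placeholders, not part of this diff.
    # Roughly equivalent CLI invocation (assuming mlc_chat is importable):
    #   python -m mlc_chat.cli.compile --config ./CodeLlama-7b-Instruct-hf \
    #       --quantization q4f16_1 --device auto --host auto -o model.so
    from pathlib import Path

    from mlc_chat.compiler.compile import compile  # pylint: disable=redefined-builtin
    from mlc_chat.support.auto_config import detect_config
    from mlc_chat.support.auto_target import detect_target_and_host

    config = detect_config(Path("./CodeLlama-7b-Instruct-hf"))   # placeholder directory with config.json
    target, build_func = detect_target_and_host("auto", "auto")  # probe local GPU and host CPU

    compile(
        config=config,
        quantization="q4f16_1",   # one of the choices exposed by --quantization
        model_type="llama",       # must be a key of MODELS
        target=target,
        opt="",
        build_func=build_func,
        output=Path("model.so"),  # the suffix selects shared vs. static library
    )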
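For the `--device` hint, `detect_target_gpu` resolves its argument in order: generic aliases get ":generic" appended, "auto" probes each entry of AUTO_DETECT_DEVICES, a bare device name from that list probes that device, a PRESET key returns the canned target, a "<device>:<id>" string goes through `Target.from_device`, and anything else is handed to TVM's `Target` constructor. A small, hypothetical sketch of three of these paths (assuming the corresponding backends are actually available on the local machine):

    from mlc_chat.support.auto_target import detect_target_and_host

    target, build = detect_target_and_host("cuda:0", "auto")        # probe one specific device
    target, build = detect_target_and_host("metal:x86-64", "auto")  # use a PRESET entry
    target, build = detect_target_and_host("nvidia/geforce-rtx-3090", "auto")  # raw TVM target tag/string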