Add CLI commands for compilation
junrushao committed Oct 24, 2023
1 parent 7ae8c6d commit 25b8dab
Showing 10 changed files with 587 additions and 7 deletions.
141 changes: 141 additions & 0 deletions python/mlc_chat/cli/compile.py
@@ -0,0 +1,141 @@
"""Command line entrypoint of compilation."""
import argparse
import json
import logging
from pathlib import Path
from typing import Union

from mlc_chat.compiler.compile import compile # pylint: disable=redefined-builtin
from mlc_chat.compiler.model import MODELS

from ..support.auto_config import detect_config
from ..support.auto_target import detect_target_and_host

logging.basicConfig(
level=logging.DEBUG,
style="{",
datefmt="%Y-%m-%d %H:%M:%S",
format="[{asctime}] {levelname} {filename}:{lineno}: {message}",
)


def _parse_config(path: Union[str, Path]) -> Path:
try:
return detect_config(Path(path))
except ValueError as err:
raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}")


def _parse_output(path: Union[str, Path]) -> Path:
path = Path(path)
parent = path.parent
if not parent.is_dir():
raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}")
return path


def _parse_model_type(model_type: str, config: Path) -> str:
if model_type == "auto":
with open(config, "r", encoding="utf-8") as config_file:
cfg = json.load(config_file)
if "model_type" not in cfg:
raise ValueError(
f"'model_type' not found in: {config}. "
f"Please explicitly specify `--model-type` instead"
)
model_type = cfg["model_type"]
if model_type not in MODELS:
raise ValueError(f"Unknown model type: {model_type}. Available ones: {list(MODELS.keys())}")
return model_type


def main():
"""Parse command line argumennts and call `mlc_llm.compiler.compile`."""
parser = argparse.ArgumentParser("MLC LLM Compiler")
parser.add_argument(
"--config",
type=_parse_config,
required=True,
help="Path to config.json file or to the directory that contains config.json, which is "
"a HuggingFace standard that defines model architecture, for example, "
"https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json",
)
parser.add_argument(
"--quantization",
type=str,
required=True,
choices=[
"q0f16",
"q0f32",
"q3f16_1",
"q3f32_1",
"q4f16_1",
"q4f16_ft",
"q4f32_1",
],
help="The quantization format. TBD",
)
parser.add_argument(
"--model-type",
type=str,
default="auto",
choices=["auto"] + list(MODELS.keys()),
help="Model architecture, for example, llama. If not set, it is inferred "
"from the config.json file.",
)
parser.add_argument(
"--device",
type=str,
default="auto",
help="The GPU device to compile the model to. If not set, it is inferred from locally "
"available GPUs.",
)
parser.add_argument(
"--host",
type=str,
default="auto",
choices=[
"auto",
"arm",
"arm64",
"aarch64",
"x86-64",
],
help="The host CPU ISA to compile the model to. If not set, it is inferred from the "
"local CPU.",
)
parser.add_argument(
"--opt",
type=str,
default="",
help="Optimization flags.",
)
parser.add_argument(
"--output",
"-o",
type=_parse_output,
required=True,
help="The name of the output file. The suffix determines if the output file is a "
"shared library or a static library. Available suffixes: "
"1) Linux: .so (shared), .a (static); "
"2) macOS: .dylib (shared), .a (static); "
"3) Windows: .dll (shared), .lib (static); "
"4) Android, iOS: .tar (static); "
"5) Web: .wasm (web assembly)",
)
parsed = parser.parse_args()
target, build_func = detect_target_and_host(parsed.device, parsed.host)
parsed.model_type = _parse_model_type(parsed.model_type, parsed.config)
compile(
config=parsed.config,
quantization=parsed.quantization,
model_type=parsed.model_type,
target=target,
opt=parsed.opt,
build_func=build_func,
output=parsed.output,
)


if __name__ == "__main__":
main()
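
Because the file ends with an `if __name__ == "__main__"` guard, the entrypoint can be run as `python -m mlc_chat.cli.compile`. A minimal, hypothetical sketch of driving the same parser from Python (for example in a test); the model directory, quantization format, and output name below are placeholder values:

import sys

from mlc_chat.cli.compile import main

# Placeholder arguments: point --config at a directory containing config.json
# (or at the file itself) and pick one of the supported quantization formats.
sys.argv = [
    "compile",
    "--config", "/path/to/CodeLlama-7b-Instruct-hf",
    "--quantization", "q4f16_1",
    "--output", "/tmp/model.so",
]
main()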
51 changes: 51 additions & 0 deletions python/mlc_chat/compiler/compile.py
@@ -0,0 +1,51 @@
"""Python entrypoint of compilation."""
import dataclasses
import logging
from io import StringIO
from pathlib import Path
from typing import Callable

from tvm import IRModule
from tvm.target import Target

from ..support.style import bold

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class CompileArgs:
"""Arguments to MLC LLM's compiler."""

config: Path
quantization: str
model_type: str
target: Target
opt: str
build_func: Callable[[IRModule, "CompileArgs"], None]
output: Path


def _echo_args(args: CompileArgs) -> None:
out = StringIO()
print(f"{bold('Compiling with arguments:')}", file=out)
print(f" {bold('--config'):<25} {args.config}", file=out)
print(f" {bold('--quantization'):<25} {args.quantization}", file=out)
print(f" {bold('--model-type'):<25} {args.model_type}", file=out)
print(f" {bold('--target'):<25} {args.target.export()}", file=out)
print(f" {bold('--opt'):<25} {args.opt}", file=out)
print(f" {bold('--output'):<25} {args.output}", file=out)
print(out.getvalue().rstrip())


def compile( # pylint: disable=too-many-arguments,redefined-builtin
config: Path,
quantization,
model_type: str,
target: Target,
opt,
build_func: Callable[[IRModule, CompileArgs], None],
output: Path,
):
args = CompileArgs(config, quantization, model_type, target, opt, build_func, output)
_echo_args(args)
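
The same pipeline can also be invoked from Python directly. A hypothetical sketch mirroring what the CLI's `main()` does; the config path, quantization format, and output name are placeholders:

from pathlib import Path

from mlc_chat.compiler.compile import compile  # pylint: disable=redefined-builtin
from mlc_chat.support.auto_config import detect_config
from mlc_chat.support.auto_target import detect_target_and_host

config = detect_config(Path("/path/to/model"))               # locate config.json
target, build_func = detect_target_and_host("auto", "auto")  # infer GPU target and host ISA
compile(
    config=config,
    quantization="q4f16_1",
    model_type="llama",
    target=target,
    opt="",
    build_func=build_func,
    output=Path("/tmp/model.so"),
)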
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/model/__init__.py
@@ -1,2 +1,2 @@
"""Model definition for the compiler."""
from . import llama, llama_config, llama_parameter
from .model import MODELS, Model
@@ -156,11 +156,11 @@ def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor


class LlamaForCasualLM(nn.Module):
def __init__(self, config: LlamaConfig, dtype: str = "float32"):
def __init__(self, config: LlamaConfig):
self.model = LlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.vocab_size = config.vocab_size
self.dtype = dtype
self.dtype = "float32"

def to(self, dtype: Optional[str] = None):
super().to(dtype=dtype)
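
With this change the model is always constructed in float32 and cast afterwards via `.to()`, instead of receiving a dtype in the constructor. A hypothetical sketch of the new flow; the config path is a placeholder and the module locations are inferred from the imports elsewhere in this commit:

from pathlib import Path

from mlc_chat.compiler.model.llama_config import LlamaConfig
from mlc_chat.compiler.model.llama_model import LlamaForCasualLM

config = LlamaConfig.from_file(Path("config.json"))  # placeholder path
model = LlamaForCasualLM(config)                      # parameters start as float32
model.to("float16")                                   # cast after construction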
5 changes: 3 additions & 2 deletions python/mlc_chat/compiler/model/llama_parameter.py
@@ -7,10 +7,11 @@
import numpy as np

from ..parameter import ExternMapping
from .llama import LlamaConfig, LlamaForCasualLM
from .llama_config import LlamaConfig
from .llama_model import LlamaForCasualLM


def hf_torch(model_config: LlamaConfig) -> ExternMapping:
def huggingface(model_config: LlamaConfig, _) -> ExternMapping:
"""Returns a parameter mapping that maps from the names of MLC LLM parameters to
the names of HuggingFace PyTorch parameters.
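
A hypothetical usage sketch of the renamed loader: the second argument is the quantization config required by the registry's loader signature, which this HuggingFace loader ignores; the config path is a placeholder:

from pathlib import Path

from mlc_chat.compiler.model.llama_config import LlamaConfig
from mlc_chat.compiler.model.llama_parameter import huggingface

config = LlamaConfig.from_file(Path("config.json"))  # placeholder path
mapping = huggingface(config, None)                   # ExternMapping of MLC LLM -> HF names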
37 changes: 37 additions & 0 deletions python/mlc_chat/compiler/model/model.py
@@ -0,0 +1,37 @@
"""A centralized registry of all existing model architures and their configurations."""
import dataclasses
from pathlib import Path
from typing import Any, Callable, Dict, Optional

from tvm.relax.frontend import nn

from ..parameter import ExternMapping, QuantizeMapping
from . import llama_config, llama_model, llama_parameter

ModelConfig = Any
QuantizeConfig = Any

LoaderType = Callable[[ModelConfig, QuantizeConfig], ExternMapping]
QuantizerType = Callable[[ModelConfig, QuantizeConfig], QuantizeMapping]


@dataclasses.dataclass
class Model:
"""All about a model architecture: its configuration, its parameter loader and quantization."""

model: Callable[[ModelConfig], nn.Module]
config: Callable[[Path], ModelConfig]
source_loader_huggingface: Optional[LoaderType] = None
source_loader_awq: Optional[LoaderType] = None
quantizer_group_quant: Optional[QuantizerType] = None


MODELS: Dict[str, Model] = {
"llama": Model(
model=llama_model.LlamaForCasualLM,
config=llama_config.LlamaConfig.from_file,
source_loader_huggingface=llama_parameter.huggingface,
source_loader_awq=None,
quantizer_group_quant=None,
)
}
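
A hypothetical sketch of consuming the registry: look an architecture up by name, parse its configuration, and instantiate the model (the config path is a placeholder):

from pathlib import Path

from mlc_chat.compiler.model import MODELS

llama = MODELS["llama"]
config = llama.config(Path("config.json"))   # LlamaConfig.from_file
module = llama.model(config)                 # LlamaForCasualLM instance
if llama.source_loader_huggingface is not None:
    mapping = llama.source_loader_huggingface(config, None)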
6 changes: 5 additions & 1 deletion python/mlc_chat/support/auto_config.py
@@ -2,8 +2,12 @@
import logging
from pathlib import Path

from .style import green, red

logger = logging.getLogger(__name__)

FOUND = green("Found")


def detect_config(config_path: Path) -> Path:
"""Detect and return the path that points to config.json. If config_path is a directory,
@@ -30,5 +34,5 @@ def detect_config(config_path: Path) -> Path:
else:
config_json_path = config_path

logger.info("Found config.json: %s", config_json_path)
logger.info("%s model configuration: %s", FOUND, config_json_path)
return config_json_path
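
A hypothetical usage sketch: `detect_config` accepts either a directory that contains config.json or the config.json path itself (both paths below are placeholders):

from pathlib import Path

from mlc_chat.support.auto_config import detect_config

cfg = detect_config(Path("/path/to/CodeLlama-7b-Instruct-hf"))               # directory
cfg = detect_config(Path("/path/to/CodeLlama-7b-Instruct-hf/config.json"))   # explicit file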