diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py index 845f7e431e..60384090de 100644 --- a/python/mlc_chat/cli/compile.py +++ b/python/mlc_chat/cli/compile.py @@ -7,13 +7,13 @@ from mlc_chat.compiler import ( # pylint: disable=redefined-builtin HELP, MODELS, - QUANTIZATION, + ModelConfigOverride, OptimizationFlags, compile, ) from ..support.argparse import ArgumentParser -from ..support.auto_config import detect_config, detect_model_type +from ..support.auto_config import detect_mlc_chat_config, detect_model_type from ..support.auto_target import detect_target_and_host @@ -39,18 +39,11 @@ def _check_system_lib_prefix(prefix: str) -> str: parser = ArgumentParser("MLC LLM Compiler") parser.add_argument( "--model", - type=detect_config, + type=detect_mlc_chat_config, required=True, dest="config", help=HELP["model"] + " (required)", ) - parser.add_argument( - "--quantization", - type=str, - required=True, - choices=list(QUANTIZATION.keys()), - help=HELP["quantization"] + " (required, choices: %(choices)s)", - ) parser.add_argument( "--model-type", type=str, @@ -82,12 +75,6 @@ def _check_system_lib_prefix(prefix: str) -> str: default="", help=HELP["system_lib_prefix"] + ' (default: "%(default)s")', ) - parser.add_argument( - "--context-window-size", - type=int, - default=None, - help=HELP["context_window_size"] + ' (default: "%(default)s")', - ) parser.add_argument( "--output", "-o", @@ -96,30 +83,21 @@ def _check_system_lib_prefix(prefix: str) -> str: help=HELP["output_compile"] + " (required)", ) parser.add_argument( - "--sliding-window", - type=int, - default=None, - help=HELP["sliding_window"] + ' (default: "%(default)s")', - ) - parser.add_argument( - "--prefill-chunk-size", - type=int, - default=None, - help=HELP["prefill_chunk_size"] + ' (default: "%(default)s")', + "--overrides", + type=ModelConfigOverride.from_str, + default="", + help=HELP["overrides"] + ' (default: "%(default)s")', ) parsed = parser.parse_args(argv) target, 
build_func = detect_target_and_host(parsed.device, parsed.host) parsed.model_type = detect_model_type(parsed.model_type, parsed.config) compile( config=parsed.config, - quantization=QUANTIZATION[parsed.quantization], model_type=parsed.model_type, target=target, opt=parsed.opt, build_func=build_func, system_lib_prefix=parsed.system_lib_prefix, output=parsed.output, - context_window_size=parsed.context_window_size, - sliding_window=parsed.sliding_window, - prefill_chunk_size=parsed.prefill_chunk_size, + overrides=parsed.overrides, ) diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py index 2cfe39556c..86eb382ee0 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/compiler/compile.py @@ -4,7 +4,7 @@ import logging from io import StringIO from pathlib import Path -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Tuple from tvm import IRModule, relax from tvm.relax.frontend import nn @@ -15,7 +15,7 @@ from .flags_model_config_override import ModelConfigOverride from .flags_optimization import OptimizationFlags from .model import Model -from .quantization import Quantization +from .quantization import QUANTIZATION, Quantization logger = logging.getLogger(__name__) @@ -45,7 +45,7 @@ def display(self) -> None: print(f" {bold('--opt'):<25} {self.opt}", file=out) print(f" {bold('--system-lib-prefix'):<25} \"{self.system_lib_prefix}\"", file=out) print(f" {bold('--output'):<25} {self.output}", file=out) - print(f" {bold('--overrides'):<25} {dataclasses.asdict(self.overrides)}", file=out) + print(f" {bold('--overrides'):<25} {self.overrides}", file=out) print(out.getvalue().rstrip()) @@ -106,9 +106,8 @@ def _attach_variable_bounds(mod, model_config): mod[g_var] = func.with_attr("tir_var_upper_bound", tir_bound_map) -def _compile(args: CompileArgs): +def _compile(args: CompileArgs, model_config: ConfigBase): logger.info("Creating model from: %s", args.config) - model_config = 
args.model.config.from_file(args.config) args.overrides.apply(model_config) model, _ = args.model.quantize[args.quantization.kind](model_config, args.quantization) logger.info("Exporting the model to TVM Unity compiler") @@ -127,18 +126,17 @@ def _compile(args: CompileArgs): def compile( # pylint: disable=too-many-arguments,redefined-builtin config: Path, - quantization: Quantization, model_type: Model, target: Target, opt: OptimizationFlags, build_func: Callable[[IRModule, CompileArgs], None], system_lib_prefix: str, output: Path, - context_window_size: Optional[int], - sliding_window: Optional[int], - prefill_chunk_size: Optional[int], + overrides: ModelConfigOverride, ): """Compile a model given its configuration and quantization format to a specific target.""" + model_config = model_type.config.from_file(config) + quantization = QUANTIZATION[model_config.kwargs["quantization"]] args = CompileArgs( config, quantization, @@ -148,11 +146,7 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin build_func, system_lib_prefix, output, - ModelConfigOverride( - context_window_size=context_window_size, - sliding_window=sliding_window, - prefill_chunk_size=prefill_chunk_size, - ), + overrides, ) args.display() - _compile(args) + _compile(args, model_config) diff --git a/python/mlc_chat/compiler/flags_model_config_override.py b/python/mlc_chat/compiler/flags_model_config_override.py index e37decf736..8365488179 100644 --- a/python/mlc_chat/compiler/flags_model_config_override.py +++ b/python/mlc_chat/compiler/flags_model_config_override.py @@ -1,6 +1,8 @@ """Flags for overriding model config.""" +import argparse import dataclasses import logging +from io import StringIO from typing import Optional from ..support.style import bold @@ -13,10 +15,19 @@ class ModelConfigOverride: """Flags for overriding model config.""" context_window_size: Optional[int] = None + prefill_chunk_size: Optional[int] = None + sliding_window: Optional[int] = None max_batch_size: 
Optional[int] = None num_shards: Optional[int] = None - sliding_window: Optional[int] = None - prefill_chunk_size: Optional[int] = None + + def __repr__(self) -> str: + out = StringIO() + print(f"context_window_size={self.context_window_size}", file=out, end="") + print(f";prefill_chunk_size={self.prefill_chunk_size}", file=out, end="") + print(f";sliding_window={self.sliding_window}", file=out, end="") + print(f";max_batch_size={self.max_batch_size}", file=out, end="") + print(f";num_shards={self.num_shards}", file=out, end="") + return out.getvalue().rstrip() def apply(self, model_config): """Apply the overrides to the given model config.""" @@ -28,12 +39,14 @@ def apply(self, model_config): self.context_window_size, ) model_config.context_window_size = self.context_window_size - if self.max_batch_size is not None: - model_config.max_batch_size = self.max_batch_size - if self.num_shards is not None: - model_config.num_shards = self.num_shards - - # Handle sliding window and sliding window chunk size + if self.prefill_chunk_size is not None: + logger.info( + "Overriding %s from %d to %d", + bold("prefill_chunk_size"), + model_config.prefill_chunk_size, + self.prefill_chunk_size, + ) + model_config.prefill_chunk_size = self.prefill_chunk_size if self.sliding_window is not None: logger.info( "Overriding %s from %d to %d", @@ -50,11 +63,38 @@ def apply(self, model_config): model_config.sliding_window, ) model_config.prefill_chunk_size = self.sliding_window - if self.prefill_chunk_size is not None: + if self.max_batch_size is not None: logger.info( "Overriding %s from %d to %d", - bold("prefill_chunk_size"), - model_config.prefill_chunk_size, - self.prefill_chunk_size, + bold("max_batch_size"), + model_config.max_batch_size, + self.max_batch_size, ) - model_config.prefill_chunk_size = self.prefill_chunk_size + model_config.max_batch_size = self.max_batch_size + if self.num_shards is not None: + logger.info( + "Overriding %s from %d to %d", + bold("num_shards"), + 
model_config.num_shards, + self.num_shards, + ) + model_config.num_shards = self.num_shards + + @staticmethod + def from_str(source: str) -> "ModelConfigOverride": + """Parse model config override values from a string.""" + + parser = argparse.ArgumentParser(description="model config override values") + parser.add_argument("--context_window_size", type=int, default=None) + parser.add_argument("--prefill_chunk_size", type=int, default=None) + parser.add_argument("--sliding_window", type=int, default=None) + parser.add_argument("--max_batch_size", type=int, default=None) + parser.add_argument("--num_shards", type=int, default=None) + results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) + return ModelConfigOverride( + context_window_size=results.context_window_size, + prefill_chunk_size=results.prefill_chunk_size, + sliding_window=results.sliding_window, + max_batch_size=results.max_batch_size, + num_shards=results.num_shards, + ) diff --git a/python/mlc_chat/compiler/flags_optimization.py b/python/mlc_chat/compiler/flags_optimization.py index 704903b419..bb60adae00 100644 --- a/python/mlc_chat/compiler/flags_optimization.py +++ b/python/mlc_chat/compiler/flags_optimization.py @@ -6,7 +6,7 @@ @dataclasses.dataclass class OptimizationFlags: - """Optiization flags""" + """Optimization flags""" cutlass_attn: bool = True cutlass_norm: bool = True diff --git a/python/mlc_chat/compiler/gen_mlc_chat_config.py b/python/mlc_chat/compiler/gen_mlc_chat_config.py index 0b0ecd667b..ddee7cc46f 100644 --- a/python/mlc_chat/compiler/gen_mlc_chat_config.py +++ b/python/mlc_chat/compiler/gen_mlc_chat_config.py @@ -28,7 +28,7 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes quantization: str = None model_config: Dict[str, Any] = None vocab_size: int = None - max_window_size: int = None + context_window_size: int = None temperature: float = None repetition_penalty: float = None @@ -74,7 +74,7 @@ def gen_config( # pylint: 
disable=too-many-locals,too-many-arguments,too-many-b
         model_config=model_config_json,
         vocab_size=model_config.vocab_size,
         conv_template=conv_template,
-        max_window_size=model_config.context_window_size,
+        context_window_size=model_config.context_window_size,
     )
     # Step 1. Load `config.json`
     for key, value in model_config.__dict__.items():
@@ -140,8 +140,8 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
     mlc_chat_config_dict = dataclasses.asdict(mlc_chat_config)
     if mlc_chat_config_dict["sliding_window"] is not None:
-        del mlc_chat_config_dict["max_window_size"]
-        logger.info("[CleanUp] Deleting %s", bold("max_window_size"))
+        del mlc_chat_config_dict["context_window_size"]
+        logger.info("[CleanUp] Deleting %s", bold("context_window_size"))
     for key, value in list(mlc_chat_config_dict.items()):
         if value is None:
             del mlc_chat_config_dict[key]
diff --git a/python/mlc_chat/compiler/help.py b/python/mlc_chat/compiler/help.py
index feb16d45cc..71c15dd066 100644
--- a/python/mlc_chat/compiler/help.py
+++ b/python/mlc_chat/compiler/help.py
@@ -99,5 +99,11 @@
 (Experimental) The chunk size during prefilling. By default, the chunk size is the same as
 sliding window or max sequence length. This flag subjects to future refactoring.
+""".strip(),
+    "overrides": """
+Model configuration override. Configurations to override `mlc-chat-config.json`.
+Supports `context_window_size`, `prefill_chunk_size`, `sliding_window`, `max_batch_size`
+and `num_shards`. Meanwhile, model config could be explicitly specified via detailed
+knobs, e.g. --overrides "context_window_size=1024;prefill_chunk_size=128".
+""".strip(),
 }
diff --git a/python/mlc_chat/support/auto_config.py b/python/mlc_chat/support/auto_config.py
index d266403355..5515594c05 100644
--- a/python/mlc_chat/support/auto_config.py
+++ b/python/mlc_chat/support/auto_config.py
@@ -15,6 +15,38 @@
 FOUND = green("Found")
 
 
+def detect_mlc_chat_config(mlc_chat_config: Union[str, Path]) -> Path:
+    """Detect and return the path that points to mlc-chat-config.json.
+    If `mlc_chat_config` is a directory, it looks for mlc-chat-config.json below it.
+
+    Parameters
+    ----------
+    mlc_chat_config : Union[str, pathlib.Path]
+        The path to `mlc-chat-config.json`, or the directory containing
+        `mlc-chat-config.json`.
+
+    Returns
+    -------
+    mlc_chat_config_json_path : pathlib.Path
+        The path that points to mlc-chat-config.json.
+    """
+
+    mlc_chat_config_path = Path(mlc_chat_config)
+    if not mlc_chat_config_path.exists():
+        raise ValueError(f"{mlc_chat_config_path} does not exist.")
+
+    if mlc_chat_config_path.is_dir():
+        # search mlc-chat-config.json under the config directory
+        mlc_chat_config_json_path = mlc_chat_config_path / "mlc-chat-config.json"
+        if not mlc_chat_config_json_path.exists():
+            raise ValueError(f"Failed to find mlc-chat-config.json under {mlc_chat_config_path}.")
+    else:
+        mlc_chat_config_json_path = mlc_chat_config_path
+
+    logger.info("%s model configuration: %s", FOUND, mlc_chat_config_json_path)
+    return mlc_chat_config_json_path
+
+
 def detect_config(config: Union[str, Path]) -> Path:
     """Detect and return the path that points to config.json. If `config` is a directory,
     it looks for config.json below it.
diff --git a/python/mlc_chat/support/config.py b/python/mlc_chat/support/config.py index 9e42b815bc..0f06a9707a 100644 --- a/python/mlc_chat/support/config.py +++ b/python/mlc_chat/support/config.py @@ -40,6 +40,17 @@ def from_dict(cls: Type[ConfigClass], source: Dict[str, Any]) -> ConfigClass: field_names = [field.name for field in dataclasses.fields(cls)] # type: ignore[arg-type] fields = {k: v for k, v in source.items() if k in field_names} kwargs = {k: v for k, v in source.items() if k not in field_names} + if "model_config" in source: + fields |= { + k: v + for k, v in source["model_config"].items() + if (k in field_names) and (k not in fields) + } + kwargs |= { + k: v + for k, v in source["model_config"].items() + if (k not in field_names) and (k not in fields) + } return cls(**fields, kwargs=kwargs) # type: ignore[call-arg] @classmethod