Skip to content

Commit

Permalink
[SLM] UX mlc_chat compile improvement (#1371)
Browse files Browse the repository at this point in the history
* `config.json` to `mlc-chat-config.json`

* remove overlap args

* remove quantization arg overlap

* add `--overrides` arg

* add `--overrides` help

* fix lint
  • Loading branch information
davidpissarra authored Dec 3, 2023
1 parent 65506f3 commit 9200380
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 63 deletions.
38 changes: 8 additions & 30 deletions python/mlc_chat/cli/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from mlc_chat.compiler import ( # pylint: disable=redefined-builtin
HELP,
MODELS,
QUANTIZATION,
ModelConfigOverride,
OptimizationFlags,
compile,
)

from ..support.argparse import ArgumentParser
from ..support.auto_config import detect_config, detect_model_type
from ..support.auto_config import detect_mlc_chat_config, detect_model_type
from ..support.auto_target import detect_target_and_host


Expand All @@ -39,18 +39,11 @@ def _check_system_lib_prefix(prefix: str) -> str:
parser = ArgumentParser("MLC LLM Compiler")
parser.add_argument(
"--model",
type=detect_config,
type=detect_mlc_chat_config,
required=True,
dest="config",
help=HELP["model"] + " (required)",
)
parser.add_argument(
"--quantization",
type=str,
required=True,
choices=list(QUANTIZATION.keys()),
help=HELP["quantization"] + " (required, choices: %(choices)s)",
)
parser.add_argument(
"--model-type",
type=str,
Expand Down Expand Up @@ -82,12 +75,6 @@ def _check_system_lib_prefix(prefix: str) -> str:
default="",
help=HELP["system_lib_prefix"] + ' (default: "%(default)s")',
)
parser.add_argument(
"--context-window-size",
type=int,
default=None,
help=HELP["context_window_size"] + ' (default: "%(default)s")',
)
parser.add_argument(
"--output",
"-o",
Expand All @@ -96,30 +83,21 @@ def _check_system_lib_prefix(prefix: str) -> str:
help=HELP["output_compile"] + " (required)",
)
parser.add_argument(
"--sliding-window",
type=int,
default=None,
help=HELP["sliding_window"] + ' (default: "%(default)s")',
)
parser.add_argument(
"--prefill-chunk-size",
type=int,
default=None,
help=HELP["prefill_chunk_size"] + ' (default: "%(default)s")',
"--overrides",
type=ModelConfigOverride.from_str,
default="",
help=HELP["overrides"] + ' (default: "%(default)s")',
)
parsed = parser.parse_args(argv)
target, build_func = detect_target_and_host(parsed.device, parsed.host)
parsed.model_type = detect_model_type(parsed.model_type, parsed.config)
compile(
config=parsed.config,
quantization=QUANTIZATION[parsed.quantization],
model_type=parsed.model_type,
target=target,
opt=parsed.opt,
build_func=build_func,
system_lib_prefix=parsed.system_lib_prefix,
output=parsed.output,
context_window_size=parsed.context_window_size,
sliding_window=parsed.sliding_window,
prefill_chunk_size=parsed.prefill_chunk_size,
overrides=parsed.overrides,
)
24 changes: 9 additions & 15 deletions python/mlc_chat/compiler/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from io import StringIO
from pathlib import Path
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Tuple

from tvm import IRModule, relax
from tvm.relax.frontend import nn
Expand All @@ -15,7 +15,7 @@
from .flags_model_config_override import ModelConfigOverride
from .flags_optimization import OptimizationFlags
from .model import Model
from .quantization import Quantization
from .quantization import QUANTIZATION, Quantization

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -45,7 +45,7 @@ def display(self) -> None:
print(f" {bold('--opt'):<25} {self.opt}", file=out)
print(f" {bold('--system-lib-prefix'):<25} \"{self.system_lib_prefix}\"", file=out)
print(f" {bold('--output'):<25} {self.output}", file=out)
print(f" {bold('--overrides'):<25} {dataclasses.asdict(self.overrides)}", file=out)
print(f" {bold('--overrides'):<25} {self.overrides}", file=out)
print(out.getvalue().rstrip())


Expand Down Expand Up @@ -106,9 +106,8 @@ def _attach_variable_bounds(mod, model_config):
mod[g_var] = func.with_attr("tir_var_upper_bound", tir_bound_map)


def _compile(args: CompileArgs):
def _compile(args: CompileArgs, model_config: ConfigBase):
logger.info("Creating model from: %s", args.config)
model_config = args.model.config.from_file(args.config)
args.overrides.apply(model_config)
model, _ = args.model.quantize[args.quantization.kind](model_config, args.quantization)
logger.info("Exporting the model to TVM Unity compiler")
Expand All @@ -127,18 +126,17 @@ def _compile(args: CompileArgs):

def compile( # pylint: disable=too-many-arguments,redefined-builtin
config: Path,
quantization: Quantization,
model_type: Model,
target: Target,
opt: OptimizationFlags,
build_func: Callable[[IRModule, CompileArgs], None],
system_lib_prefix: str,
output: Path,
context_window_size: Optional[int],
sliding_window: Optional[int],
prefill_chunk_size: Optional[int],
overrides: ModelConfigOverride,
):
"""Compile a model given its configuration and quantization format to a specific target."""
model_config = model_type.config.from_file(config)
quantization = QUANTIZATION[model_config.kwargs["quantization"]]
args = CompileArgs(
config,
quantization,
Expand All @@ -148,11 +146,7 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin
build_func,
system_lib_prefix,
output,
ModelConfigOverride(
context_window_size=context_window_size,
sliding_window=sliding_window,
prefill_chunk_size=prefill_chunk_size,
),
overrides,
)
args.display()
_compile(args)
_compile(args, model_config)
66 changes: 53 additions & 13 deletions python/mlc_chat/compiler/flags_model_config_override.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Flags for overriding model config."""
import argparse
import dataclasses
import logging
from io import StringIO
from typing import Optional

from ..support.style import bold
Expand All @@ -13,10 +15,19 @@ class ModelConfigOverride:
"""Flags for overriding model config."""

context_window_size: Optional[int] = None
prefill_chunk_size: Optional[int] = None
sliding_window: Optional[int] = None
max_batch_size: Optional[int] = None
num_shards: Optional[int] = None
sliding_window: Optional[int] = None
prefill_chunk_size: Optional[int] = None

def __repr__(self) -> str:
out = StringIO()
print(f"context_window_size={self.context_window_size}", file=out, end="")
print(f";prefill_chunk_size={self.prefill_chunk_size}", file=out, end="")
print(f";sliding_window={self.sliding_window}", file=out, end="")
print(f";max_batch_size={self.max_batch_size}", file=out, end="")
print(f";num_shards={self.num_shards}", file=out, end="")
return out.getvalue().rstrip()

def apply(self, model_config):
"""Apply the overrides to the given model config."""
Expand All @@ -28,12 +39,14 @@ def apply(self, model_config):
self.context_window_size,
)
model_config.context_window_size = self.context_window_size
if self.max_batch_size is not None:
model_config.max_batch_size = self.max_batch_size
if self.num_shards is not None:
model_config.num_shards = self.num_shards

# Handle sliding window and sliding window chunk size
if self.prefill_chunk_size is not None:
logger.info(
"Overriding %s from %d to %d",
bold("prefill_chunk_size"),
model_config.prefill_chunk_size,
self.prefill_chunk_size,
)
model_config.prefill_chunk_size = self.prefill_chunk_size
if self.sliding_window is not None:
logger.info(
"Overriding %s from %d to %d",
Expand All @@ -50,11 +63,38 @@ def apply(self, model_config):
model_config.sliding_window,
)
model_config.prefill_chunk_size = self.sliding_window
if self.prefill_chunk_size is not None:
if self.max_batch_size is not None:
logger.info(
"Overriding %s from %d to %d",
bold("prefill_chunk_size"),
model_config.prefill_chunk_size,
self.prefill_chunk_size,
bold("max_batch_size"),
model_config.max_batch_size,
self.max_batch_size,
)
model_config.prefill_chunk_size = self.prefill_chunk_size
model_config.max_batch_size = self.max_batch_size
if self.num_shards is not None:
logger.info(
"Overriding %s from %d to %d",
bold("num_shards"),
model_config.num_shards,
self.num_shards,
)
model_config.num_shards = self.num_shards

@staticmethod
def from_str(source: str) -> "ModelConfigOverride":
"""Parse model config override values from a string."""

parser = argparse.ArgumentParser(description="model config override values")
parser.add_argument("--context_window_size", type=int, default=None)
parser.add_argument("--prefill_chunk_size", type=int, default=None)
parser.add_argument("--sliding_window", type=int, default=None)
parser.add_argument("--max_batch_size", type=int, default=None)
parser.add_argument("--num_shards", type=int, default=None)
results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
return ModelConfigOverride(
context_window_size=results.context_window_size,
prefill_chunk_size=results.prefill_chunk_size,
sliding_window=results.sliding_window,
max_batch_size=results.max_batch_size,
num_shards=results.num_shards,
)
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/flags_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

@dataclasses.dataclass
class OptimizationFlags:
"""Optiization flags"""
"""Optimization flags"""

cutlass_attn: bool = True
cutlass_norm: bool = True
Expand Down
8 changes: 4 additions & 4 deletions python/mlc_chat/compiler/gen_mlc_chat_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes
quantization: str = None
model_config: Dict[str, Any] = None
vocab_size: int = None
max_window_size: int = None
context_window_size: int = None

temperature: float = None
repetition_penalty: float = None
Expand Down Expand Up @@ -74,7 +74,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
model_config=model_config_json,
vocab_size=model_config.vocab_size,
conv_template=conv_template,
max_window_size=model_config.context_window_size,
context_window_size=model_config.context_window_size,
)
# Step 1. Load `config.json`
for key, value in model_config.__dict__.items():
Expand Down Expand Up @@ -140,8 +140,8 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b

mlc_chat_config_dict = dataclasses.asdict(mlc_chat_config)
if mlc_chat_config_dict["sliding_window"] is not None:
del mlc_chat_config_dict["max_window_size"]
logger.info("[CleanUp] Deleting %s", bold("max_window_size"))
del mlc_chat_config_dict["context_window_size"]
logger.info("[CleanUp] Deleting %s", bold("context_window_size"))
for key, value in list(mlc_chat_config_dict.items()):
if value is None:
del mlc_chat_config_dict[key]
Expand Down
6 changes: 6 additions & 0 deletions python/mlc_chat/compiler/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,11 @@
(Experimental) The chunk size during prefilling. By default,
the chunk size is the same as sliding window or max sequence length.
This flag subjects to future refactoring.
""".strip(),
"overrides": """
Model configuration override. Configurations to override `mlc-chat-config.json`.
Supports `context_window_size`, `prefill_chunk_size`, `sliding_window`, `max_batch_size`
and `num_shards`. Meanwhile, model config could be explicitly specified via details
knobs, e.g. --overrides "context_window_size=1024;prefill_chunk_size=128".
""".strip(),
}
32 changes: 32 additions & 0 deletions python/mlc_chat/support/auto_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,38 @@
FOUND = green("Found")


def detect_mlc_chat_config(mlc_chat_config: Union[str, Path]) -> Path:
    """Detect and return the path that points to ``mlc-chat-config.json``.

    If `mlc_chat_config` is a directory, look for ``mlc-chat-config.json``
    directly below it.

    Parameters
    ----------
    mlc_chat_config : Union[str, pathlib.Path]
        The path to `mlc-chat-config.json`, or the directory containing
        `mlc-chat-config.json`.

    Returns
    -------
    mlc_chat_config_json_path : pathlib.Path
        The path that points to `mlc-chat-config.json`.

    Raises
    ------
    ValueError
        If the given path does not exist, or if a directory was given and no
        `mlc-chat-config.json` exists under it.
    """
    mlc_chat_config_path = Path(mlc_chat_config)
    if not mlc_chat_config_path.exists():
        raise ValueError(f"{mlc_chat_config_path} does not exist.")

    if mlc_chat_config_path.is_dir():
        # A directory was given: search for mlc-chat-config.json under it.
        mlc_chat_config_json_path = mlc_chat_config_path / "mlc-chat-config.json"
        if not mlc_chat_config_json_path.exists():
            # The file is named `mlc-chat-config.json` (dashes), so report
            # exactly that name instead of the misleading `mlc_chat_config.json`.
            raise ValueError(
                f"Failed to find mlc-chat-config.json under {mlc_chat_config_path}."
            )
    else:
        mlc_chat_config_json_path = mlc_chat_config_path

    logger.info("%s model configuration: %s", FOUND, mlc_chat_config_json_path)
    return mlc_chat_config_json_path


def detect_config(config: Union[str, Path]) -> Path:
"""Detect and return the path that points to config.json. If `config` is a directory,
it looks for config.json below it.
Expand Down
11 changes: 11 additions & 0 deletions python/mlc_chat/support/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ def from_dict(cls: Type[ConfigClass], source: Dict[str, Any]) -> ConfigClass:
field_names = [field.name for field in dataclasses.fields(cls)] # type: ignore[arg-type]
fields = {k: v for k, v in source.items() if k in field_names}
kwargs = {k: v for k, v in source.items() if k not in field_names}
if "model_config" in source:
fields |= {
k: v
for k, v in source["model_config"].items()
if (k in field_names) and (k not in fields)
}
kwargs |= {
k: v
for k, v in source["model_config"].items()
if (k not in field_names) and (k not in fields)
}
return cls(**fields, kwargs=kwargs) # type: ignore[call-arg]

@classmethod
Expand Down

0 comments on commit 9200380

Please sign in to comment.