[SLM] UX improvement: HF download & presets (#1374)
davidpissarra authored Dec 11, 2023
1 parent bc5cb4b commit 002dc89
Showing 6 changed files with 114 additions and 24 deletions.
19 changes: 16 additions & 3 deletions python/mlc_chat/cli/compile.py
@@ -15,7 +15,11 @@
)

from ..support.argparse import ArgumentParser
from ..support.auto_config import detect_mlc_chat_config, detect_model_type
from ..support.auto_config import (
detect_mlc_chat_config,
detect_model_type,
detect_quantization,
)
from ..support.auto_target import detect_target_and_host


@@ -46,6 +50,13 @@ def _check_system_lib_prefix(prefix: str) -> str:
type=detect_mlc_chat_config,
help=HELP["model"] + " (required)",
)
parser.add_argument(
"--quantization",
type=str,
required=False,
choices=list(QUANTIZATION.keys()),
help=HELP["quantization"] + " (required, choices: %(choices)s)",
)
parser.add_argument(
"--model-type",
type=str,
@@ -93,11 +104,13 @@ def _check_system_lib_prefix(prefix: str) -> str:
parsed = parser.parse_args(argv)
target, build_func = detect_target_and_host(parsed.device, parsed.host)
parsed.model_type = detect_model_type(parsed.model_type, parsed.model)
parsed.quantization = detect_quantization(parsed.quantization, parsed.model)
with open(parsed.model, "r", encoding="utf-8") as config_file:
config = json.load(config_file)

compile(
config=config["model_config"],
quantization=QUANTIZATION[config["quantization"]],
config=config,
quantization=parsed.quantization,
model_type=parsed.model_type,
target=target,
opt=parsed.opt,
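Taken together, the hunks above change the CLI contract: `--quantization` is declared `required=False`, yet its help text still says "required" because a scheme must come from either the flag or the config file. A minimal sketch of the resulting resolution order, using names from this diff (the preset tag and scheme here are just examples):

    # Sketch of the new flow in main(), assuming names from the hunks above.
    # A preset tag carries no "quantization" field, so the flag must be given;
    # a full mlc-chat-config.json may supply it instead.
    from mlc_chat.support.auto_config import detect_mlc_chat_config, detect_quantization

    config_path = detect_mlc_chat_config("llama2_13b")          # preset tag -> temp JSON path
    quantization = detect_quantization("q4f16_1", config_path)  # flag wins; config is fallback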
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/__init__.py
@@ -11,4 +11,4 @@
from .help import HELP
from .loader import LOADER, ExternMapping, HuggingFaceLoader, QuantizeMapping
from .model import MODEL_PRESETS, MODELS, Model
from .quantization import QUANTIZATION
from .quantization import QUANTIZATION, Quantization
8 changes: 6 additions & 2 deletions python/mlc_chat/compiler/compile.py
@@ -153,7 +153,7 @@ def _compile(args: CompileArgs, model_config: ConfigBase):

def compile( # pylint: disable=too-many-arguments,redefined-builtin
config: Dict[str, Any],
quantization: str,
quantization: Quantization,
model_type: Model,
target: Target,
opt: OptimizationFlags,
@@ -163,7 +163,11 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin
overrides: ModelConfigOverride,
):
"""Compile a model given its configuration and quantization format to a specific target."""
model_config = model_type.config.from_dict(config)
if "model_config" in config:
model_config = model_type.config.from_dict({**config["model_config"], **config})
else:
model_config = model_type.config.from_dict(config)

args = CompileArgs(
model_config,
quantization,
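One subtlety in the merge above: because `config` is spread last, top-level keys in mlc-chat-config.json override duplicates nested under "model_config". A small self-contained illustration (the values are hypothetical):

    # Top-level keys win over duplicates inside "model_config".
    config = {
        "quantization": "q4f16_1",
        "context_window_size": 4096,
        "model_config": {"context_window_size": 2048, "vocab_size": 32000},
    }
    merged = {**config["model_config"], **config}
    assert merged["context_window_size"] == 4096  # top-level override wins
    assert merged["vocab_size"] == 32000          # nested-only keys survive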
20 changes: 16 additions & 4 deletions python/mlc_chat/compiler/model/model.py
@@ -120,7 +120,6 @@ class Model:
"initializer_range": 0.02,
"intermediate_size": 11008,
"max_position_embeddings": 2048,
"context_window_size": 4096,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
@@ -134,6 +133,8 @@ class Model:
"transformers_version": "4.31.0.dev0",
"use_cache": True,
"vocab_size": 32000,
"context_window_size": 2048,
"prefill_chunk_size": 2048,
},
"llama2_13b": {
"_name_or_path": "meta-llama/Llama-2-13b-hf",
@@ -145,7 +146,6 @@ class Model:
"initializer_range": 0.02,
"intermediate_size": 13824,
"max_position_embeddings": 2048,
"context_window_size": 4096,
"model_type": "llama",
"num_attention_heads": 40,
"num_hidden_layers": 40,
@@ -159,6 +159,8 @@ class Model:
"transformers_version": "4.31.0.dev0",
"use_cache": True,
"vocab_size": 32000,
"context_window_size": 2048,
"prefill_chunk_size": 2048,
},
"llama2_70b": {
"architectures": ["LlamaForCausalLM"],
@@ -169,7 +171,6 @@ class Model:
"initializer_range": 0.02,
"intermediate_size": 28672,
"max_position_embeddings": 2048,
"context_window_size": 4096,
"model_type": "llama",
"num_attention_heads": 64,
"num_hidden_layers": 80,
@@ -181,6 +182,8 @@ class Model:
"transformers_version": "4.31.0.dev0",
"use_cache": True,
"vocab_size": 32000,
"context_window_size": 2048,
"prefill_chunk_size": 2048,
},
"codellama_7b": {
"_name_or_path": "codellama/CodeLlama-7b-hf",
@@ -205,6 +208,8 @@ class Model:
"transformers_version": "4.33.0.dev0",
"use_cache": True,
"vocab_size": 32016,
"context_window_size": 2048,
"prefill_chunk_size": 2048,
},
"codellama_13b": {
"architectures": ["LlamaForCausalLM"],
@@ -228,6 +233,8 @@ class Model:
"transformers_version": "4.32.0.dev0",
"use_cache": True,
"vocab_size": 32016,
"context_window_size": 2048,
"prefill_chunk_size": 2048,
},
"codellama_34b": {
"architectures": ["LlamaForCausalLM"],
@@ -251,6 +258,8 @@ class Model:
"transformers_version": "4.32.0.dev0",
"use_cache": True,
"vocab_size": 32016,
"context_window_size": 2048,
"prefill_chunk_size": 2048,
},
"mistral_7b": {
"architectures": ["MistralForCausalLM"],
@@ -267,12 +276,13 @@ class Model:
"num_key_value_heads": 8,
"rms_norm_eps": 1e-05,
"rope_theta": 10000.0,
"sliding_window": 4096,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.34.0.dev0",
"use_cache": True,
"vocab_size": 32000,
"sliding_window": 4096,
"prefill_chunk_size": 128,
},
"gpt2": {
"architectures": ["GPT2LMHeadModel"],
@@ -289,6 +299,8 @@ class Model:
"transformers_version": "4.26.0.dev0",
"use_cache": True,
"vocab_size": 50257,
"context_window_size": 2048,
"prefill_chunk_size": 2048,
},
"redpajama_3b_v1": {
"_name_or_path": "/root/fm/models/rp_3b_800b_real_fp16",
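The preset edits above move each model's `context_window_size` into the preset body (set to 2048) and add a new `prefill_chunk_size` field. A quick sketch of reading those fields back, assuming the `MODEL_PRESETS` export added to `compiler/__init__.py`:

    from mlc_chat.compiler import MODEL_PRESETS

    preset = MODEL_PRESETS["llama2_13b"]
    assert preset["context_window_size"] == 2048  # was 4096 before this commit
    assert preset["prefill_chunk_size"] == 2048   # mistral_7b instead uses 128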
77 changes: 68 additions & 9 deletions python/mlc_chat/support/auto_config.py
@@ -2,26 +2,30 @@
import json
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING

from mlc_chat.compiler import QUANTIZATION, Quantization

from . import logging
from .download import download_mlc_weights
from .style import bold, green

if TYPE_CHECKING:
from mlc_chat.compiler import Model # pylint: disable=unused-import


logger = logging.getLogger(__name__)

FOUND = green("Found")


def detect_mlc_chat_config(mlc_chat_config: Union[str, Path]) -> Path:
def detect_mlc_chat_config(mlc_chat_config: str) -> Path:
"""Detect and return the path that points to mlc-chat-config.json.
If `mlc_chat_config` is a directory, it looks for mlc-chat-config.json below it.
Parameters
---------
mlc_chat_config : Union[str, pathlib.Path]
mlc_chat_config : str
The path to `mlc-chat-config.json`, or the directory containing
`mlc-chat-config.json`.
@@ -30,13 +34,31 @@ def detect_mlc_chat_config(mlc_chat_config: Union[str, Path]) -> Path:
mlc_chat_config_json_path : pathlib.Path
The path that points to mlc-chat-config.json.
"""
from mlc_chat.compiler import ( # pylint: disable=import-outside-toplevel
MODEL_PRESETS,
)

mlc_chat_config_path = Path(mlc_chat_config)
if mlc_chat_config.startswith("HF://") or mlc_chat_config.startswith("http"):
mlc_chat_config_path = Path(download_mlc_weights(model_url=mlc_chat_config))
elif isinstance(mlc_chat_config, str) and mlc_chat_config in MODEL_PRESETS:
logger.info("%s mlc preset model: %s", FOUND, mlc_chat_config)
content = MODEL_PRESETS[mlc_chat_config].copy()
content["model_preset_tag"] = mlc_chat_config
temp_file = tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with
suffix=".json",
delete=False,
)
logger.info("Dumping config to: %s", temp_file.name)
mlc_chat_config_path = Path(temp_file.name)
with mlc_chat_config_path.open("w", encoding="utf-8") as mlc_chat_config_file:
json.dump(content, mlc_chat_config_file, indent=2)
else:
mlc_chat_config_path = Path(mlc_chat_config)
if not mlc_chat_config_path.exists():
raise ValueError(f"{mlc_chat_config_path} does not exist.")

if mlc_chat_config_path.is_dir():
# search config.json under config path
# search mlc-chat-config.json under path
mlc_chat_config_json_path = mlc_chat_config_path / "mlc-chat-config.json"
if not mlc_chat_config_json_path.exists():
raise ValueError(f"Fail to find mlc_chat_config.json under {mlc_chat_config_path}.")
@@ -47,13 +69,13 @@ def detect_mlc_chat_config(mlc_chat_config: Union[str, Path]) -> Path:
return mlc_chat_config_json_path
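With this hunk, `detect_mlc_chat_config` accepts three spellings of a model: a preset tag (dumped to a temporary JSON file), a HuggingFace URL (downloaded into the weight cache), or a local path. A usage sketch; the repo and path are hypothetical:

    from mlc_chat.support.auto_config import detect_mlc_chat_config

    detect_mlc_chat_config("llama2_13b")                 # preset tag -> temp mlc-chat-config.json
    detect_mlc_chat_config("HF://some-user/some-repo")   # hypothetical repo, fetched to the cache
    detect_mlc_chat_config("/path/to/model")             # hypothetical local dir or config file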


def detect_config(config: Union[str, Path]) -> Path:
def detect_config(config: str) -> Path:
"""Detect and return the path that points to config.json. If `config` is a directory,
it looks for config.json below it.
Parameters
---------
config : Union[str, pathlib.Path]
config : str
The preset name of the model, or the path to `config.json`, or the directory containing
`config.json`.
@@ -122,13 +144,50 @@ def detect_model_type(model_type: str, config: Path) -> "Model":
if model_type == "auto":
with open(config, "r", encoding="utf-8") as config_file:
cfg = json.load(config_file)
if "model_type" not in cfg:
if "model_type" not in cfg and (
"model_config" not in cfg or "model_type" not in cfg["model_config"]
):
raise ValueError(
f"'model_type' not found in: {config}. "
f"Please explicitly specify `--model-type` instead."
)
model_type = cfg["model_type"]
model_type = cfg["model_type"] if "model_type" in cfg else cfg["model_config"]["model_type"]
logger.info("%s model type: %s. Use `--model-type` to override.", FOUND, bold(model_type))
if model_type not in MODELS:
raise ValueError(f"Unknown model type: {model_type}. Available ones: {list(MODELS.keys())}")
return MODELS[model_type]
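The extra condition above lets `detect_model_type` handle both config shapes: a plain HuggingFace config.json with a top-level "model_type", and a dumped preset or mlc-chat-config.json where it sits under "model_config". For instance:

    # Both shapes resolve to the same model type.
    hf_style = {"model_type": "llama"}
    preset_style = {"model_config": {"model_type": "llama"}}
    for cfg in (hf_style, preset_style):
        found = cfg["model_type"] if "model_type" in cfg else cfg["model_config"]["model_type"]
        assert found == "llama"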


def detect_quantization(quantization_arg: str, config: Path) -> Quantization:
"""Detect the model quantization scheme from the configuration file or `--quantization`
argument. If `--quantization` is provided, it will override the value on the configuration
file.
Parameters
----------
quantization_arg : str
The quantization scheme, for example, "q4f16_1".
config : pathlib.Path
The path to mlc-chat-config.json.
Returns
-------
quantization : mlc_chat.compiler.Quantization
The model quantization scheme.
"""

with open(config, "r", encoding="utf-8") as config_file:
cfg = json.load(config_file)

if quantization_arg is not None:
quantization = QUANTIZATION[quantization_arg]
elif "quantization" in cfg:
quantization = QUANTIZATION[cfg["quantization"]]
else:
raise ValueError(
f"'quantization' not found in: {config}. "
f"Please explicitly specify `--quantization` instead."
)

return quantization
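In the common non-preset case the flag can simply be omitted, for example (the config path is hypothetical):

    from pathlib import Path
    from mlc_chat.support.auto_config import detect_quantization

    # No --quantization flag: fall back to the "quantization" field of the config.
    q = detect_quantization(None, Path("dist/model/mlc-chat-config.json"))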
12 changes: 7 additions & 5 deletions python/mlc_chat/support/download.py
@@ -129,22 +129,24 @@ def download_mlc_weights( # pylint: disable=too-many-locals
model_url: str,
num_processes: int = 4,
force_redo: bool = False,
) -> None:
) -> Path:
"""Download weights for a model from the HuggingFace Git LFS repo."""
mlc_prefix = "HF://"
prefixes, mlc_prefix = ["HF://", "https://huggingface.co/"], ""
mlc_prefix = next(p for p in prefixes if model_url.startswith(p))
assert mlc_prefix

git_url_template = "https://huggingface.co/{user}/{repo}.git"
bin_url_template = "https://huggingface.co/{user}/{repo}/resolve/main/{record_name}"

if model_url.count("/") != 1 + mlc_prefix.count("/") or not model_url.startswith(mlc_prefix):
raise ValueError(f"Invalid model URL: {model_url}")
assert model_url.startswith(mlc_prefix)
user, repo = model_url[len(mlc_prefix) :].split("/")
git_dir = get_cache_dir() / "model_weights" / repo
try:
_ensure_directory_not_exist(git_dir, force_redo=force_redo)
except ValueError:
logger.info("Weights already downloaded: %s", git_dir)
return
return git_dir
with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir_prefix:
tmp_dir = Path(tmp_dir_prefix) / "tmp"
git_url = git_url_template.format(user=user, repo=repo)
@@ -166,4 +168,4 @@ def download_mlc_weights( # pylint: disable=too-many-locals
logger.info("Downloaded %s to %s", file_url, file_dest)
logger.info("Moving %s to %s", tmp_dir, git_dir)
shutil.move(str(tmp_dir), str(git_dir))
shutil.move(str(tmp_dir), str(git_dir))
return git_dir
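After this change `download_mlc_weights` both accepts a plain https://huggingface.co/ URL and returns the local cache directory, so callers such as `detect_mlc_chat_config` can chain on the result. A sketch (the repo name is hypothetical):

    from mlc_chat.support.download import download_mlc_weights

    # Either prefix is accepted; the return value is the cached checkout's path.
    path = download_mlc_weights("HF://some-user/some-repo")
    path = download_mlc_weights("https://huggingface.co/some-user/some-repo")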
