Commit 97c5156
add mlc presets
davidpissarra committed Dec 6, 2023
1 parent adc2949 commit 97c5156
Showing 12 changed files with 101 additions and 43 deletions.
4 changes: 1 addition & 3 deletions python/mlc_chat/cli/compile.py
@@ -8,7 +8,6 @@
 from mlc_chat.compiler import (  # pylint: disable=redefined-builtin
     HELP,
     MODELS,
-    QUANTIZATION,
     ModelConfigOverride,
     OptimizationFlags,
     compile,
@@ -97,8 +96,7 @@ def _check_system_lib_prefix(prefix: str) -> str:
     with open(parsed.model, "r", encoding="utf-8") as config_file:
         config = json.load(config_file)
     compile(
-        config=config["model_config"],
-        quantization=QUANTIZATION[config["quantization"]],
+        config=config,
         model_type=parsed.model_type,
         target=target,
         opt=parsed.opt,
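Taken together, these two hunks change the CLI contract: instead of unpacking config["model_config"] and resolving QUANTIZATION[...] itself, the command now forwards the entire mlc-chat-config.json dict and leaves both steps to compile(). A minimal sketch of the new call shape, assuming an on-disk config (the path is illustrative and the remaining arguments are elided):

import json

from mlc_chat.compiler import compile  # pylint: disable=redefined-builtin

# Illustrative path; any mlc-chat-config.json carrying "model_config"
# and "quantization" keys flows through the same way.
with open("dist/llama2_7b-q4f16_1/mlc-chat-config.json", encoding="utf-8") as f:
    config = json.load(f)

# The whole dict goes in; compile() now looks up the quantization scheme
# from config["quantization"] itself, so the CLI no longer imports QUANTIZATION.
# compile(config=config, model_type=..., target=..., opt=..., ...)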
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/__init__.py
@@ -10,5 +10,5 @@
 from .gen_config import CONV_TEMPLATES, gen_config
 from .help import HELP
 from .loader import LOADER, ExternMapping, HuggingFaceLoader, QuantizeMapping
-from .model import MODEL_PRESETS, MODELS, Model
+from .model import HF_PRESETS, MLC_PRESETS, MODELS, Model
 from .quantization import QUANTIZATION
6 changes: 3 additions & 3 deletions python/mlc_chat/compiler/compile.py
@@ -15,7 +15,7 @@
 from .flags_model_config_override import ModelConfigOverride
 from .flags_optimization import OptimizationFlags
 from .model import Model
-from .quantization import Quantization
+from .quantization import QUANTIZATION, Quantization
 
 logger = logging.getLogger(__name__)
 
@@ -123,7 +123,6 @@ def _compile(args: CompileArgs, model_config: ConfigBase):
 
 def compile(  # pylint: disable=too-many-arguments,redefined-builtin
     config: Dict[str, Any],
-    quantization: str,
     model_type: Model,
     target: Target,
     opt: OptimizationFlags,
@@ -133,7 +132,8 @@ def compile(  # pylint: disable=too-many-arguments,redefined-builtin
     overrides: ModelConfigOverride,
 ):
     """Compile a model given its configuration and quantization format to a specific target."""
-    model_config = model_type.config.from_dict(config)
+    model_config = model_type.config.from_dict({**config["model_config"], **config})
+    quantization = QUANTIZATION[config["quantization"]]
     args = CompileArgs(
         model_config,
         quantization,
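The merge on the new line deserves a note: in {**config["model_config"], **config}, later keys win, so flat MLC-level fields (context_window_size, prefill_chunk_size, sliding_window) override any same-named field nested in the HuggingFace config. A self-contained sketch with invented values:

# Later keys win in a Python dict merge, so the top-level MLC fields
# override same-named fields inside "model_config". Values invented
# for illustration.
hf_config = {"hidden_size": 4096, "context_window_size": 4096}
mlc_config = {
    "model_config": hf_config,
    "context_window_size": 2048,  # MLC runtime override
    "quantization": "q4f16_1",
}
merged = {**mlc_config["model_config"], **mlc_config}
assert merged["hidden_size"] == 4096          # inherited from the HF config
assert merged["context_window_size"] == 2048  # the MLC value wins

Note that the merged dict still carries the extra "model_config" and "quantization" keys, so this pattern relies on from_dict ignoring fields the model config class does not declare.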
4 changes: 2 additions & 2 deletions python/mlc_chat/compiler/help.py
@@ -1,5 +1,5 @@
 """Help message for CLI arguments."""
-from .model import MODEL_PRESETS
+from .model import HF_PRESETS
 
 HELP = {
     "model": (
@@ -19,7 +19,7 @@
         Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main.
         Pre-defined model architectures include """
-        + ", ".join(f'"{preset}"' for preset in MODEL_PRESETS)
+        + ", ".join(f'"{preset}"' for preset in HF_PRESETS)
         + "."
     ).strip(),
     "quantization": """
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/model/__init__.py
@@ -1,3 +1,3 @@
 """Model definition for the compiler."""
 from . import llama, mistral
-from .model import MODEL_PRESETS, MODELS, Model
+from .model import HF_PRESETS, MLC_PRESETS, MODELS, Model
57 changes: 52 additions & 5 deletions python/mlc_chat/compiler/model/model.py
@@ -96,7 +96,7 @@ class Model:
     ),
 }
 
-MODEL_PRESETS: Dict[str, Any] = {
+HF_PRESETS: Dict[str, Any] = {
     "llama2_7b": {
         "architectures": ["LlamaForCausalLM"],
         "bos_token_id": 1,
@@ -106,7 +106,6 @@
         "initializer_range": 0.02,
         "intermediate_size": 11008,
         "max_position_embeddings": 2048,
-        "context_window_size": 4096,
         "model_type": "llama",
         "num_attention_heads": 32,
         "num_hidden_layers": 32,
@@ -131,7 +130,6 @@
         "initializer_range": 0.02,
         "intermediate_size": 13824,
         "max_position_embeddings": 2048,
-        "context_window_size": 4096,
         "model_type": "llama",
         "num_attention_heads": 40,
         "num_hidden_layers": 40,
@@ -155,7 +153,6 @@
         "initializer_range": 0.02,
         "intermediate_size": 28672,
         "max_position_embeddings": 2048,
-        "context_window_size": 4096,
         "model_type": "llama",
         "num_attention_heads": 64,
         "num_hidden_layers": 80,
@@ -253,7 +250,6 @@ class Model:
         "num_key_value_heads": 8,
         "rms_norm_eps": 1e-05,
         "rope_theta": 10000.0,
-        "sliding_window": 4096,
         "tie_word_embeddings": False,
         "torch_dtype": "bfloat16",
         "transformers_version": "4.34.0.dev0",
@@ -277,3 +273,54 @@ class Model:
         "vocab_size": 50257,
     },
 }
+
+MLC_PRESETS: Dict[str, Any] = {
+    "llama2_7b": {
+        "model_config": HF_PRESETS["llama2_7b"],
+        "context_window_size": 2048,
+        "prefill_chunk_size": 2048,
+        "quantization": "q4f16_1",
+    },
+    "llama2_13b": {
+        "model_config": HF_PRESETS["llama2_13b"],
+        "context_window_size": 2048,
+        "prefill_chunk_size": 2048,
+        "quantization": "q4f16_1",
+    },
+    "llama2_70b": {
+        "model_config": HF_PRESETS["llama2_70b"],
+        "context_window_size": 2048,
+        "prefill_chunk_size": 2048,
+        "quantization": "q4f16_1",
+    },
+    "codellama_7b": {
+        "model_config": HF_PRESETS["codellama_7b"],
+        "context_window_size": 2048,
+        "prefill_chunk_size": 2048,
+        "quantization": "q4f16_1",
+    },
+    "codellama_13b": {
+        "model_config": HF_PRESETS["codellama_13b"],
+        "context_window_size": 2048,
+        "prefill_chunk_size": 2048,
+        "quantization": "q4f16_1",
+    },
+    "codellama_34b": {
+        "model_config": HF_PRESETS["codellama_34b"],
+        "context_window_size": 2048,
+        "prefill_chunk_size": 2048,
+        "quantization": "q4f16_1",
+    },
+    "mistral_7b": {
+        "model_config": HF_PRESETS["mistral_7b"],
+        "sliding_window": 4096,
+        "prefill_chunk_size": 128,
+        "quantization": "q4f16_1",
+    },
+    "gpt2": {
+        "model_config": HF_PRESETS["gpt2"],
+        "context_window_size": 2048,
+        "prefill_chunk_size": 2048,
+        "quantization": "q4f16_1",
+    },
+}
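Each MLC preset wraps the matching HF preset under "model_config" and layers the runtime fields on top: context_window_size (or sliding_window for mistral_7b), prefill_chunk_size, and a default quantization. A short usage sketch against the dicts above:

from mlc_chat.compiler import MLC_PRESETS

preset = MLC_PRESETS["llama2_7b"]
assert preset["quantization"] == "q4f16_1"
assert preset["context_window_size"] == 2048
# The nested HF config is the same object as HF_PRESETS["llama2_7b"],
# shared by reference (note the shallow .copy() in auto_config.py below):
assert preset["model_config"]["num_hidden_layers"] == 32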
47 changes: 30 additions & 17 deletions python/mlc_chat/support/auto_config.py
@@ -2,11 +2,11 @@
 import json
 import tempfile
 from pathlib import Path
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING
 
 from . import logging
-from .style import bold, green
 from .download import download_mlc_weights
+from .style import bold, green
 
 if TYPE_CHECKING:
     from mlc_chat.compiler import Model  # pylint: disable=unused-import
@@ -16,13 +16,13 @@
 FOUND = green("Found")
 
 
-def detect_mlc_chat_config(mlc_chat_config: Union[str, Path]) -> Path:
+def detect_mlc_chat_config(mlc_chat_config: str) -> Path:
     """Detect and return the path that points to mlc-chat-config.json.
     If `mlc_chat_config` is a directory, it looks for mlc-chat-config.json below it.
 
     Parameters
     ---------
-    mlc_chat_config : Union[str, pathlib.Path]
+    mlc_chat_config : str
         The path to `mlc-chat-config.json`, or the directory containing
         `mlc-chat-config.json`.
@@ -31,16 +31,29 @@ def detect_mlc_chat_config(mlc_chat_config: Union[str, Path]) -> Path:
     Returns
     -------
     mlc_chat_config_json_path : pathlib.Path
         The path points to mlc_chat_config.json.
     """
+    from mlc_chat.compiler import MLC_PRESETS  # pylint: disable=import-outside-toplevel
+
     if mlc_chat_config.startswith("HF://") or mlc_chat_config.startswith("http"):
-        mlc_chat_config = download_mlc_weights(model_url=mlc_chat_config)
-
-    mlc_chat_config_path = Path(mlc_chat_config)
+        mlc_chat_config_path = Path(download_mlc_weights(model_url=mlc_chat_config))
+    elif isinstance(mlc_chat_config, str) and mlc_chat_config in MLC_PRESETS:
+        logger.info("%s mlc preset model: %s", FOUND, mlc_chat_config)
+        content = MLC_PRESETS[mlc_chat_config].copy()
+        content["model_preset_tag"] = mlc_chat_config
+        temp_file = tempfile.NamedTemporaryFile(  # pylint: disable=consider-using-with
+            suffix=".json",
+            delete=False,
+        )
+        logger.info("Dumping config to: %s", temp_file.name)
+        mlc_chat_config_path = Path(temp_file.name)
+        with mlc_chat_config_path.open("w", encoding="utf-8") as mlc_chat_config_file:
+            json.dump(content, mlc_chat_config_file, indent=2)
+    else:
+        mlc_chat_config_path = Path(mlc_chat_config)
     if not mlc_chat_config_path.exists():
         raise ValueError(f"{mlc_chat_config_path} does not exist.")
 
     if mlc_chat_config_path.is_dir():
-        # search config.json under config path
+        # search mlc-chat-config.json under path
         mlc_chat_config_json_path = mlc_chat_config_path / "mlc-chat-config.json"
         if not mlc_chat_config_json_path.exists():
             raise ValueError(f"Fail to find mlc_chat_config.json under {mlc_chat_config_path}.")
@@ -51,13 +64,13 @@ def detect_mlc_chat_config(mlc_chat_config: Union[str, Path]) -> Path:
     return mlc_chat_config_json_path
 
 
-def detect_config(config: Union[str, Path]) -> Path:
+def detect_config(config: str) -> Path:
     """Detect and return the path that points to config.json. If `config` is a directory,
     it looks for config.json below it.
 
     Parameters
     ---------
-    config : Union[str, pathlib.Path]
+    config : str
         The preset name of the model, or the path to `config.json`, or the directory containing
         `config.json`.
@@ -66,13 +79,11 @@ def detect_config(config: Union[str, Path]) -> Path:
     Returns
    -------
     config_json_path : pathlib.Path
         The path points to config.json.
     """
-    from mlc_chat.compiler import (  # pylint: disable=import-outside-toplevel
-        MODEL_PRESETS,
-    )
+    from mlc_chat.compiler import HF_PRESETS  # pylint: disable=import-outside-toplevel
 
-    if isinstance(config, str) and config in MODEL_PRESETS:
+    if isinstance(config, str) and config in HF_PRESETS:
         logger.info("%s preset model: %s", FOUND, config)
-        content = MODEL_PRESETS[config].copy()
+        content = HF_PRESETS[config].copy()
         content["model_preset_tag"] = config
         temp_file = tempfile.NamedTemporaryFile(  # pylint: disable=consider-using-with
             suffix=".json",
@@ -126,12 +137,14 @@ def detect_model_type(model_type: str, config: Path) -> "Model":
     if model_type == "auto":
         with open(config, "r", encoding="utf-8") as config_file:
             cfg = json.load(config_file)
-        if "model_type" not in cfg:
+        if "model_type" not in cfg and (
+            "model_config" not in cfg or "model_type" not in cfg["model_config"]
+        ):
             raise ValueError(
                 f"'model_type' not found in: {config}. "
                 f"Please explicitly specify `--model-type` instead."
             )
-        model_type = cfg["model_type"]
+        model_type = cfg["model_type"] if "model_type" in cfg else cfg["model_config"]["model_type"]
     logger.info("%s model type: %s. Use `--model-type` to override.", FOUND, bold(model_type))
     if model_type not in MODELS:
         raise ValueError(f"Unknown model type: {model_type}. Available ones: {list(MODELS.keys())}")
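With these two functions updated, a preset name now resolves end to end: detect_mlc_chat_config dumps the matching MLC preset, tagged with model_preset_tag, into a temporary *.json file, and detect_model_type finds model_type whether it sits at the top level or nested under model_config. A hedged sketch of the round trip, assuming the pre-existing file-path branch returns the temporary file's path unchanged:

import json

from mlc_chat.support.auto_config import detect_mlc_chat_config, detect_model_type

config_path = detect_mlc_chat_config("llama2_7b")  # any key of MLC_PRESETS
cfg = json.loads(config_path.read_text(encoding="utf-8"))
assert cfg["model_preset_tag"] == "llama2_7b"
assert cfg["model_config"]["model_type"] == "llama"  # nested, not top-level

model = detect_model_type("auto", config_path)  # falls back to the nested field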
4 changes: 2 additions & 2 deletions tests/python/loader/test_awq.py
@@ -5,7 +5,7 @@
 import pytest
 import tvm
 
-from mlc_chat.compiler import MODEL_PRESETS, MODELS, QUANTIZATION
+from mlc_chat.compiler import HF_PRESETS, MODELS, QUANTIZATION
 from mlc_chat.compiler.loader import HuggingFaceLoader
 from mlc_chat.support import logging, tqdm
@@ -24,7 +24,7 @@ def test_load_llama(param_path: Union[str, Path]):
 
     model = MODELS["llama"]
     quantization = QUANTIZATION["q4f16_awq"]
-    config = model.config.from_dict(MODEL_PRESETS["llama2_7b"])
+    config = model.config.from_dict(HF_PRESETS["llama2_7b"])
     loader = HuggingFaceLoader(
         path=path_params,
         extern_param_map=model.source["awq"](config, quantization),
4 changes: 2 additions & 2 deletions tests/python/model/test_gpt2.py
@@ -1,13 +1,13 @@
 # pylint: disable=invalid-name,missing-docstring
 import pytest
 
-from mlc_chat.compiler import MODEL_PRESETS, MODELS
+from mlc_chat.compiler import HF_PRESETS, MODELS
 
 
 @pytest.mark.parametrize("model_name", ["gpt2"])
 def test_gpt2_creation(model_name: str):
     model_info = MODELS["gpt2"]
-    config = model_info.config.from_dict(MODEL_PRESETS[model_name])
+    config = model_info.config.from_dict(HF_PRESETS[model_name])
     model = model_info.model(config)
     mod, named_params = model.export_tvm(
         spec=model.get_default_spec(),  # type: ignore
4 changes: 2 additions & 2 deletions tests/python/model/test_llama.py
@@ -1,13 +1,13 @@
 # pylint: disable=invalid-name,missing-docstring
 import pytest
 
-from mlc_chat.compiler import MODEL_PRESETS, MODELS
+from mlc_chat.compiler import HF_PRESETS, MODELS
 
 
 @pytest.mark.parametrize("model_name", ["llama2_7b", "llama2_13b", "llama2_70b"])
 def test_llama2_creation(model_name: str):
     model_info = MODELS["llama"]
-    config = model_info.config.from_dict(MODEL_PRESETS[model_name])
+    config = model_info.config.from_dict(HF_PRESETS[model_name])
     model = model_info.model(config)
     mod, named_params = model.export_tvm(
         spec=model.get_default_spec(),  # type: ignore
6 changes: 3 additions & 3 deletions tests/python/model/test_llama_quantization.py
@@ -1,7 +1,7 @@
 # pylint: disable=invalid-name,missing-docstring
 import pytest
 
-from mlc_chat.compiler import MODEL_PRESETS, MODELS, QUANTIZATION
+from mlc_chat.compiler import HF_PRESETS, MODELS, QUANTIZATION
 from mlc_chat.compiler.quantization.group_quantization import (
     GroupQuantizeEmbedding,
     GroupQuantizeLinear,
@@ -18,7 +18,7 @@
 )
 def test_llama2_group_quantization(model_name: str, quant_name: str):
     model_info = MODELS["llama"]
-    config = model_info.config.from_dict(MODEL_PRESETS[model_name])
+    config = model_info.config.from_dict(HF_PRESETS[model_name])
     model, quant_map = model_info.quantize["group-quant"](config, QUANTIZATION[quant_name])
     assert "model.embed_tokens.weight" in quant_map.param_map
     assert isinstance(
@@ -60,7 +60,7 @@ def test_llama2_group_quantization(model_name: str, quant_name: str):
 )
 def test_llama2_no_quantization(model_name: str, quant_name: str):
     model_info = MODELS["llama"]
-    config = model_info.config.from_dict(MODEL_PRESETS[model_name])
+    config = model_info.config.from_dict(HF_PRESETS[model_name])
     _, quant_map = model_info.quantize["no-quant"](config, QUANTIZATION[quant_name])
     assert len(quant_map.param_map) == 0
     assert len(quant_map.map_func) == 0
4 changes: 2 additions & 2 deletions tests/python/model/test_mistral.py
@@ -1,13 +1,13 @@
 # pylint: disable=invalid-name,missing-docstring
 import pytest
 
-from mlc_chat.compiler import MODEL_PRESETS, MODELS
+from mlc_chat.compiler import HF_PRESETS, MODELS
 
 
 @pytest.mark.parametrize("model_name", ["mistral_7b"])
 def test_mistral_creation(model_name: str):
     model_info = MODELS["mistral"]
-    config = model_info.config.from_dict(MODEL_PRESETS[model_name])
+    config = model_info.config.from_dict(HF_PRESETS[model_name])
     model = model_info.model(config)
     mod, named_params = model.export_tvm(
         spec=model.get_default_spec(),  # type: ignore
