Establish mlc_chat.compiler (mlc-ai#1082)
This PR establishes the compiler component of the MLC-Chat Python API, which
currently includes two primary pieces: models and parameters.

The models component provides `nn.Module`-based definitions of LLMs; as the
very first stab, it contains only `LlamaForCasualLM`. It is decomposed into
three files (a usage sketch follows the list):
- `llama_config.py`: common configurations for Llama, where we define the
  relevant architecture hyperparameters, as well as standard config presets
  for Llama2-7B/13B/70B for convenient testing;
- `llama.py`: the model architecture of Llama, based on the PyTorch-like
`nn.Module` API;
- `llama_parameter.py`: defines the mapping between MLC parameters and
  PyTorch parameters.
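
For illustration, here is a minimal usage sketch; the field values are made up, and it assumes `LlamaForCasualLM` is constructed directly from a `LlamaConfig` whose fields mirror the dataclass shown in the diff below:

```python
# Hypothetical usage sketch, not part of this PR. Assumes LlamaForCasualLM takes a
# LlamaConfig and that llama_config.py keeps the dataclass fields shown in the diff.
from mlc_chat.compiler.model.llama import LlamaForCasualLM
from mlc_chat.compiler.model.llama_config import LlamaConfig

config = LlamaConfig(
    hidden_act="silu",
    hidden_size=4096,
    intermediate_size=11008,
    num_attention_heads=32,
    num_hidden_layers=32,
    rms_norm_eps=1e-5,
    vocab_size=32000,
)
model = LlamaForCasualLM(config)  # an nn.Module describing a Llama2-7B-sized model
```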

The parameters component contains the basic functionality of parameter mapping,
as well as the loaders that convert parameters from PyTorch to MLC according to
the specified mapping. Currently, only `HFTorchLoader` is implemented, but
loaders for SafeTensors, GPTQ, or AWQ should be straightforward to add under
the existing design.
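
The sketch below illustrates the idea behind such a mapping; the names and helper are hypothetical (not the PR's actual classes), and NumPy is used only to keep the example self-contained:

```python
# Conceptual sketch of a parameter mapping: each MLC parameter is produced from one
# or more PyTorch parameters via a small transform, e.g. fusing q/k/v projections.
# Names below are illustrative, not the actual mapping defined in llama_parameter.py.
from typing import Callable, Dict, List, Tuple

import numpy as np

# MLC parameter name -> (source PyTorch parameter names, transform to combine them)
MappingEntry = Tuple[List[str], Callable[..., np.ndarray]]

QKV_FUSE_MAPPING: Dict[str, MappingEntry] = {
    "model.layers.0.self_attn.qkv_proj.weight": (
        [
            "model.layers.0.self_attn.q_proj.weight",
            "model.layers.0.self_attn.k_proj.weight",
            "model.layers.0.self_attn.v_proj.weight",
        ],
        lambda q, k, v: np.concatenate([q, k, v], axis=0),
    ),
}


def load_mlc_param(name: str, torch_params: Dict[str, np.ndarray]) -> np.ndarray:
    """Apply one mapping entry, the way a loader such as HFTorchLoader would."""
    sources, transform = QKV_FUSE_MAPPING[name]
    return transform(*(torch_params[s] for s in sources))
```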

On top of this PR, on-the-fly quantization could be defined as a loading-time
transformation on MLC parameters, while pre-quantized parameter loading is
effectively parameter loading after MLC's `nn.Module` is quantized.
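
As a rough illustration only (the int8 scheme below is a stand-in, not a quantization mode that MLC LLM ships), such a loading-time transformation could look like:

```python
# Hypothetical loading-time transformation: quantize a float weight right after the
# loader materializes it, before handing it to the already-quantized nn.Module.
import numpy as np


def quantize_int8_per_row(weight: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Symmetric int8 quantization with one scale per output row (illustrative only)."""
    scale = np.maximum(np.abs(weight).max(axis=1, keepdims=True) / 127.0, 1e-8)
    quantized = np.clip(np.round(weight / scale), -127, 127).astype(np.int8)
    return quantized, scale.astype(np.float32)


fp32_weight = np.random.randn(4096, 4096).astype(np.float32)  # e.g. a freshly loaded weight
q_weight, q_scale = quantize_int8_per_row(fp32_weight)
```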

Two unit tests exemplify how the infrastructure works:
- `./tests/python/model/test_llama.py` shows how to create an `nn.Module`
using the new infra, and then convert it to a TVM IRModule (a toy sketch of
this conversion path appears after the list);
- `./tests/python/parameter/hf_torch_loader.py` shows how to load
parameters from HuggingFace PyTorch format.
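
For reference, the conversion pattern the first test exercises looks roughly like the following; a toy module is used here so the Llama-specific method signatures defined in this PR don't have to be assumed:

```python
# Toy illustration of the nn.Module -> tvm.IRModule conversion path (not the actual
# test content). Requires a TVM build that ships the relax `nn` frontend.
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op


class TinyMLP(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.proj = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return op.silu(self.proj(x))


mod, named_params = TinyMLP(16, 32).export_tvm(
    spec={"forward": {"x": nn.spec.Tensor((1, 16), "float32")}}
)
print(mod)           # the generated tvm.IRModule
print(named_params)  # [(name, nn.Parameter), ...] to be filled in by a parameter loader
```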

In addition, `mlc_chat.support` is established for utility functions; it
currently contains two utils:
- `config.py`, which supports reading configurations into dataclasses from a
JSON file or a Python dict. On top of the plain Python dataclass, it collects
unrecognized fields into `cls.kwargs`, which is helpful when loading
HuggingFace configuration files (a sketch of this behavior appears after the
list);
- `tqdm.py` which contains tqdm-related utilities, primarily redirecting
logging and printing to work nicely with tqdm.
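
A sketch of the `config.py` behavior described above (the helper name and signature here are assumptions, not the actual API):

```python
# Hypothetical helper mirroring the described behavior: known fields populate the
# dataclass, and unrecognized keys from a HuggingFace config land in `kwargs`.
import dataclasses
import json
from typing import Any, Dict, Type, TypeVar

T = TypeVar("T")


def config_from_dict(cls: Type[T], source: Dict[str, Any]) -> T:
    field_names = {f.name for f in dataclasses.fields(cls)} - {"kwargs"}
    known = {k: v for k, v in source.items() if k in field_names}
    extra = {k: v for k, v in source.items() if k not in field_names}
    return cls(**known, kwargs=extra)


@dataclasses.dataclass
class ExampleConfig:
    hidden_size: int
    vocab_size: int
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)


hf_config = json.loads('{"hidden_size": 4096, "vocab_size": 32000, "torch_dtype": "float16"}')
cfg = config_from_dict(ExampleConfig, hf_config)
assert cfg.hidden_size == 4096
assert cfg.kwargs == {"torch_dtype": "float16"}
```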
junrushao authored Oct 19, 2023
1 parent 3aefd9f commit 2625945
Showing 20 changed files with 617 additions and 385 deletions.
3 changes: 0 additions & 3 deletions mlc_llm/models/__init__.py

This file was deleted.

6 changes: 0 additions & 6 deletions mlc_llm/param_loader/__init__.py

This file was deleted.

191 changes: 0 additions & 191 deletions mlc_llm/param_loader/hf_torch_loader.py

This file was deleted.

36 changes: 0 additions & 36 deletions mlc_llm/param_loader/param_mapping.py

This file was deleted.

26 changes: 24 additions & 2 deletions python/mlc_chat/cli/benchmark.py
@@ -1,7 +1,7 @@
"""A command line tool for benchmarking a chat model."""
import argparse

from mlc_chat import ChatModule
from mlc_chat import ChatConfig, ChatModule

parser = argparse.ArgumentParser(description="Benchmark an MLC LLM ChatModule.")
parser.add_argument(
@@ -13,6 +13,21 @@
the model folder over possible paths.""",
required=True,
)
parser.add_argument(
"--model-lib",
type=str,
help="""The compiled model library. In MLC LLM, an LLM is compiled to a shared or static
library (.so or .a), which contains GPU computation to efficiently run the LLM. MLC Chat,
as the runtime of MLC LLM, depends on the compiled model library to generate tokens.
""",
required=False,
)
parser.add_argument(
"--num-shards",
type=int,
help="Number of GPUs to be used.",
required=False,
)
parser.add_argument(
"--device",
type=str,
@@ -40,7 +55,14 @@
def main():
"""The main function that runs the benchmarking."""
args = parser.parse_args()
chat_module = ChatModule(model=args.model, device=args.device)
chat_module = ChatModule(
model=args.model,
device=args.device,
chat_config=ChatConfig(
num_shards=args.num_shards,
),
lib_path=args.model_lib,
)
output = chat_module.benchmark_generate(args.prompt, generate_length=args.generate_length)
print(f"Generated text:\n{output}\n")
print(f"Statistics: {chat_module.stats(verbose=True)}")
5 changes: 5 additions & 0 deletions python/mlc_chat/compiler/__init__.py
@@ -0,0 +1,5 @@
"""
A compiler for MLC Chat. By default, it is not imported to MLC Chat to avoid unnecessary dependency,
but users could optionally import it if they want to use the compiler.
"""
from . import model, parameter
2 changes: 2 additions & 0 deletions python/mlc_chat/compiler/model/__init__.py
@@ -0,0 +1,2 @@
"""Model definition for the compiler."""
from . import llama, llama_config, llama_parameter
@@ -1,41 +1,19 @@
"""Implementation for Llama2 architecture"""
import dataclasses
"""
Implementation for Llama2 architecture.
TODO: add docstring
"""
import math
from typing import Any, Dict, Optional
from typing import Optional

from tvm import te, tir
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op

from .model_config_base import ModelConfig
from .llama_config import LlamaConfig

# pylint: disable=invalid-name,missing-docstring


@dataclasses.dataclass
class LlamaConfig(ModelConfig): # pylint: disable=too-many-instance-attributes
hidden_act: str
hidden_size: int
intermediate_size: int
num_attention_heads: int
num_hidden_layers: int
rms_norm_eps: float
vocab_size: int
max_sequence_length: int = 2048
position_embedding_base: int = 10000
num_key_value_heads: int = 0
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
head_dim: int = 0

def __post_init__(self):
if self.num_key_value_heads == 0:
self.num_key_value_heads = self.num_attention_heads
if self.head_dim == 0:
self.head_dim = self.hidden_size // self.num_attention_heads
assert self.num_attention_heads % self.num_key_value_heads == 0
assert self.head_dim * self.num_attention_heads == self.hidden_size


class RotaryEmbedding(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
