Commit ba3e8b8: Update
junrushao committed Oct 30, 2023
1 parent dd79bc6
Showing 9 changed files with 219 additions and 185 deletions.
6 changes: 3 additions & 3 deletions python/mlc_chat/cli/compile.py
@@ -6,7 +6,7 @@
 
 from mlc_chat.compiler import (  # pylint: disable=redefined-builtin
     MODELS,
-    QUANT,
+    QUANTIZATION,
     OptimizationFlags,
     compile,
 )
@@ -51,7 +51,7 @@ def _parse_output(path: Union[str, Path]) -> Path:
         "--quantization",
         type=str,
         required=True,
-        choices=list(QUANT.keys()),
+        choices=list(QUANTIZATION.keys()),
         help="Quantization format.",
     )
     parser.add_argument(
@@ -119,7 +119,7 @@ def _parse_output(path: Union[str, Path]) -> Path:
     parsed.model_type = detect_model_type(parsed.model_type, parsed.config)
     compile(
         config=parsed.config,
-        quantization=parsed.quantization,
+        quantization=QUANTIZATION[parsed.quantization],
         model_type=parsed.model_type,
         target=target,
         opt=parsed.opt,
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/__init__.py
@@ -7,4 +7,4 @@
 from .flags_optimization import OptimizationFlags
 from .model import MODEL_PRESETS, MODELS, Model
 from .parameter import ExternMapping, HuggingFaceLoader, QuantizeMapping
-from .quantization import QUANT
+from .quantization import QUANTIZATION
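
With the rename, callers import the registry under its new name. A minimal usage sketch; the "q4f16_1" key is an assumption based on MLC's usual quantization names, not something shown in this diff:

from mlc_chat.compiler import QUANTIZATION

print(sorted(QUANTIZATION))              # names accepted by --quantization
quantization = QUANTIZATION["q4f16_1"]   # a Quantization object, not a bare string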
16 changes: 8 additions & 8 deletions python/mlc_chat/compiler/compile.py
@@ -7,17 +7,18 @@
 from tvm import IRModule, relax
 from tvm.target import Target
 
-from ..compiler.model import Model
 from ..support.style import bold
 from .flags_optimization import OptimizationFlags
+from .model import Model
+from .quantization import Quantization
 
 
 @dataclasses.dataclass
 class CompileArgs:  # pylint: disable=too-many-instance-attributes
     """Arguments to MLC LLM's compiler."""
 
     config: Path
-    quantization: str
+    quantization: Quantization
     model: Model
     target: Target
     opt: OptimizationFlags
@@ -40,20 +41,19 @@ def _echo_args(args: CompileArgs) -> None:
 
 def _compile(args: CompileArgs):
     model_config = args.model.config.from_file(args.config)
-    model = args.model.model(model_config)
-    mod, named_params = model.export_tvm(
+    quantization = args.quantization
+    model, _ = args.model.quantize[quantization.kind](model_config, quantization)
+    mod, _named_params = model.export_tvm(
         spec=model.get_default_spec(),  # type: ignore
     )
     with args.target:
         mod = relax.get_pipeline("mlc_llm")(mod)
     mod.show(black_format=False)
-    for name, param in named_params:
-        print(f"{name}: {param.shape} {param.dtype}")
     args.build_func(mod, args)
 
 
 def compile(  # pylint: disable=too-many-arguments,redefined-builtin
     config: Path,
-    quantization,
+    quantization: Quantization,
     model_type: Model,
     target: Target,
     opt: OptimizationFlags,
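
The net effect: _compile now quantizes the model before export instead of instantiating it directly, dispatching on quantization.kind. A minimal sketch of the flow, assuming "llama" and "q4f16_1" are registered keys and the config path is hypothetical:

from pathlib import Path

from mlc_chat.compiler import MODELS, QUANTIZATION

model_type = MODELS["llama"]              # assumed registry key
quantization = QUANTIZATION["q4f16_1"]    # assumed registry key
model_config = model_type.config.from_file(Path("config.json"))  # hypothetical path

# quantize maps a method name to a FuncQuantization; quantization.kind
# selects the per-architecture entry, e.g. "group-quant" for GroupQuantize.
model, _quant_map = model_type.quantize[quantization.kind](model_config, quantization)
mod, _named_params = model.export_tvm(spec=model.get_default_spec())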
7 changes: 6 additions & 1 deletion python/mlc_chat/compiler/model/llama_model.py
@@ -22,6 +22,8 @@ def __init__(self, config: LlamaConfig):
 
     def forward(self, q: Tensor, k: Tensor, offset: tir.Var):
         def te_op(x: te.Tensor, offset: tir.Var):
+            dtype = x.dtype
+
             def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var):
                 head_dim = tir.const(self.head_dim, "int32")
                 position_embedding_base = tir.const(self.position_embedding_base, "float32")
@@ -30,11 +32,13 @@ def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var):
                     (d * 2 % head_dim).astype("float32") / head_dim,
                 )
                 freq = (offset + s) / freq
-                return tir.cos(freq) * x[b, s, h, d] + tir.sin(freq) * tir.if_then_else(
+                cos = tir.cos(freq).astype(dtype) * x[b, s, h, d]
+                sin = tir.sin(freq).astype(dtype) * tir.if_then_else(
                     d < self.head_dim // 2,
                     -x[b, s, h, d + self.head_dim // 2],
                     x[b, s, h, d - self.head_dim // 2],
                 )
+                return cos + sin
 
             return te.compute(x.shape, compute, name="rotary")
 
@@ -87,6 +91,7 @@ def forward(  # pylint: disable=too-many-locals
         d, h_q, h_kv, t = self.head_dim, self.num_q_heads, self.num_kv_heads, total_seq_len
         b, s, _ = hidden_states.shape
         assert b == 1, "Only support batch size 1 at this moment."
+
         q, k, v = self.qkv_proj(hidden_states)
         q = op.reshape(q, (b, s, h_q, d))
         k = op.reshape(k, (b, s, h_kv, d))
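
For orientation, a NumPy sketch of the rotary embedding that compute implements (a paraphrase, not repository code); the fix above computes cos/sin in float32 and casts them back to the activation dtype, so float16 models no longer mix dtypes inside the kernel:

import numpy as np

def rotary(x: np.ndarray, offset: int, base: float = 10000.0) -> np.ndarray:
    """Rotate x of shape (batch, seq, heads, head_dim), mirroring te_op above."""
    _, s, _, d = x.shape
    pos = (offset + np.arange(s)).reshape(1, s, 1, 1)    # absolute positions
    inv = base ** ((np.arange(d) * 2 % d) / d)           # per-channel divisor
    freq = pos / inv                                     # broadcasts to (1, s, 1, d)
    # Channel pairing mirrors tir.if_then_else: the first half pairs with
    # -x[..., i + d//2], the second half with x[..., i - d//2].
    shifted = np.concatenate([-x[..., d // 2:], x[..., : d // 2]], axis=-1)
    return np.cos(freq).astype(x.dtype) * x + np.sin(freq).astype(x.dtype) * shifted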
26 changes: 12 additions & 14 deletions python/mlc_chat/compiler/model/llama_quantization.py
@@ -1,26 +1,24 @@
-"""
-Quantization specs for Llama2 architecture.
-TODO: add docstring
-"""
+"""Quantization specs for Llama."""
 from typing import Tuple
 
 from tvm.relax.frontend import nn
 
 from ..parameter import QuantizeMapping
-from ..quantization import GroupQuantizeConfig
+from ..quantization import GroupQuantize
+from .llama_config import LlamaConfig
 from .llama_model import LlamaForCasualLM
 
 
-def llama_group_quantization(
-    model: LlamaForCasualLM, quant_config: GroupQuantizeConfig
+def group_quant(
+    model_config: LlamaConfig,
+    quantization: GroupQuantize,
 ) -> Tuple[nn.Module, QuantizeMapping]:
     """Quantize a Llama2 model using group quantization."""
+    model: nn.Module = LlamaForCasualLM(model_config)
     quant_map = QuantizeMapping({}, {})
-    for i in range(len(model.model.layers)):
-        model.model.layers[i] = quant_config.apply(
-            model.model.layers[i], quant_map, f"model.layers.{i}"
-        )
-    model.model.embed_tokens = quant_config.apply(
-        model.model.embed_tokens, quant_map, "model.embed_tokens"
+    model = quantization.quantize_model(
+        model,
+        quant_map,
+        "model",
     )
-    model.lm_head = quant_config.apply(model.lm_head, quant_map, "lm_head")
     return model, quant_map
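
The rewrite drops the hand-enumerated submodules (decoder layers, embed_tokens, lm_head) in favor of a single quantize_model call that walks the whole module tree under the "model" prefix. A hedged usage sketch; the config path and the "q4f16_1" key are assumptions:

from pathlib import Path

from mlc_chat.compiler.model.llama_config import LlamaConfig
from mlc_chat.compiler.model.llama_quantization import group_quant
from mlc_chat.compiler.quantization import QUANTIZATION

config = LlamaConfig.from_file(Path("config.json"))  # hypothetical path
quantization = QUANTIZATION["q4f16_1"]               # assumed GroupQuantize entry
model, quant_map = group_quant(config, quantization)
# quant_map now records, for each quantized parameter, how its quantized
# tensors derive from the original weights.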
21 changes: 12 additions & 9 deletions python/mlc_chat/compiler/model/model.py
@@ -1,12 +1,12 @@
 """A centralized registry of all existing model architectures and their configurations."""
 import dataclasses
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, Tuple
 
 from tvm.relax.frontend import nn
 
 from ..parameter import ExternMapping, QuantizeMapping
-from ..quantization.quantization import QuantizeConfig
+from ..quantization.quantization import Quantization
-from . import llama_config, llama_model, llama_parameter
+from . import llama_config, llama_model, llama_parameter, llama_quantization
 
 ModelConfig = Any
 """A ModelConfig is an object that represents a model architecture. It is required to have
@@ -16,8 +16,8 @@ def from_file(cls, path: Path) -> ModelConfig:
         ...
 """
 
-FuncGetExternMap = Callable[[ModelConfig, QuantizeConfig], ExternMapping]
-FuncGetQuantMap = Callable[[ModelConfig, QuantizeConfig], QuantizeMapping]
+FuncGetExternMap = Callable[[ModelConfig, Quantization], ExternMapping]
+FuncQuantization = Callable[[ModelConfig, Quantization], Tuple[nn.Module, QuantizeMapping]]
 
 
 @dataclasses.dataclass
@@ -38,15 +38,16 @@ class Model:
     source : Dict[str, FuncGetExternMap]
         A dictionary that maps the name of a source format to parameter mapping.
-    quantize: Dict[str, FuncGetQuantMap]
-        A dictionary that maps the name of a quantization method to quantization mapping.
+    quantize: Dict[str, FuncQuantization]
+        A dictionary that maps the name of a quantization method to the quantized model
+        and its quantization parameter mapping.
     """
 
     name: str
     config: ModelConfig
     model: Callable[[ModelConfig], nn.Module]
     source: Dict[str, FuncGetExternMap]
-    quantize: Dict[str, FuncGetQuantMap]
+    quantize: Dict[str, FuncQuantization]


 MODELS: Dict[str, Model] = {
@@ -58,7 +59,9 @@ class Model:
             "huggingface-torch": llama_parameter.huggingface,
             "huggingface-safetensor": llama_parameter.huggingface,
         },
-        quantize={},
+        quantize={
+            "group-quant": llama_quantization.group_quant,
+        },
     )
 }

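With this entry, group quantization for Llama is discoverable by registry lookup rather than special-cased. A small sketch, assuming the entry above is keyed as "llama":

from mlc_chat.compiler import MODELS

llama = MODELS["llama"]       # assumed key for the entry above
print(list(llama.source))     # ['huggingface-torch', 'huggingface-safetensor']
print(list(llama.quantize))   # ['group-quant']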
4 changes: 2 additions & 2 deletions python/mlc_chat/compiler/quantization/__init__.py
@@ -1,3 +1,3 @@
 """A subpackage for quantization and dequantization algorithms"""
-from .quantization import QUANT, QuantizeConfig
-from .group_quantization import GroupQuantizeConfig
+from .group_quantization import GroupQuantize
+from .quantization import QUANTIZATION, Quantization