Enable group quant transform with nn.Module #1154

Merged: 8 commits, Oct 31, 2023
python/mlc_chat/cli/compile.py (6 changes: 3 additions & 3 deletions)
@@ -6,7 +6,7 @@

from mlc_chat.compiler import ( # pylint: disable=redefined-builtin
MODELS,
QUANT,
QUANTIZATION,
OptimizationFlags,
compile,
)
@@ -51,7 +51,7 @@ def _parse_output(path: Union[str, Path]) -> Path:
"--quantization",
type=str,
required=True,
choices=list(QUANT.keys()),
choices=list(QUANTIZATION.keys()),
help="Quantization format.",
)
parser.add_argument(
@@ -119,7 +119,7 @@ def _parse_output(path: Union[str, Path]) -> Path:
parsed.model_type = detect_model_type(parsed.model_type, parsed.config)
compile(
config=parsed.config,
quantization=parsed.quantization,
quantization=QUANTIZATION[parsed.quantization],
model_type=parsed.model_type,
target=target,
opt=parsed.opt,
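The net effect of the three CLI changes above, as a hedged sketch (assuming the `mlc_chat.compiler` import path shown in the diff; `"q4f16_1"` is used only as an example key): the `--quantization` string is validated against and resolved through the renamed `QUANTIZATION` registry, so `compile` now receives a quantization config object rather than a bare string.

```python
# Minimal sketch of the new flag resolution (not the full CLI): the chosen
# string is looked up in QUANTIZATION, so downstream code gets a config object.
from mlc_chat.compiler import QUANTIZATION

print(sorted(QUANTIZATION.keys()))      # the valid --quantization choices
quantization = QUANTIZATION["q4f16_1"]  # "q4f16_1" is assumed to be one of those keys
print(type(quantization).__name__)      # e.g. GroupQuantize for a group-quant scheme
```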
python/mlc_chat/compiler/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -7,4 +7,4 @@
from .flags_optimization import OptimizationFlags
from .model import MODEL_PRESETS, MODELS, Model
from .parameter import ExternMapping, HuggingFaceLoader, QuantizeMapping
from .quantization import QUANT
from .quantization import QUANTIZATION
python/mlc_chat/compiler/compile.py (16 changes: 8 additions & 8 deletions)
@@ -7,17 +7,18 @@
from tvm import IRModule, relax
from tvm.target import Target

from ..compiler.model import Model
from ..support.style import bold
from .flags_optimization import OptimizationFlags
from .model import Model
from .quantization import Quantization


@dataclasses.dataclass
class CompileArgs: # pylint: disable=too-many-instance-attributes
"""Arguments to MLC LLM's compiler."""

config: Path
quantization: str
quantization: Quantization
model: Model
target: Target
opt: OptimizationFlags
@@ -40,20 +41,19 @@ def _echo_args(args: CompileArgs) -> None:

def _compile(args: CompileArgs):
model_config = args.model.config.from_file(args.config)
model = args.model.model(model_config)
mod, named_params = model.export_tvm(
quantization = args.quantization
model, _ = args.model.quantize[quantization.kind](model_config, quantization)
mod, _named_params = model.export_tvm(
spec=model.get_default_spec(), # type: ignore
)
with args.target:
mod = relax.get_pipeline("mlc_llm")(mod)
mod.show(black_format=False)
for name, param in named_params:
print(f"{name}: {param.shape} {param.dtype}")
args.build_func(mod, args)


def compile( # pylint: disable=too-many-arguments,redefined-builtin
config: Path,
quantization,
quantization: Quantization,
model_type: Model,
target: Target,
opt: OptimizationFlags,
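For reference, a sketch of the flow `_compile` follows after this change: quantization is now applied as an `nn.Module` transform before `export_tvm`, rather than being handled as a post-hoc parameter mapping. The `"llama"` registry key and the config path are illustrative assumptions, not part of this diff.

```python
# Hedged sketch of the new compile-time flow; mirrors _compile above.
from mlc_chat.compiler import MODELS, QUANTIZATION

model_type = MODELS["llama"]            # assumed registry key
quantization = QUANTIZATION["q4f16_1"]  # e.g. a GroupQuantize config
model_config = model_type.config.from_file("path/to/config.json")  # hypothetical path

# The per-model quantize hook returns an already-quantized nn.Module
# plus the QuantizeMapping that records the new parameter names.
model, quant_map = model_type.quantize[quantization.kind](model_config, quantization)
mod, named_params = model.export_tvm(spec=model.get_default_spec())
```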
python/mlc_chat/compiler/model/llama_model.py (7 changes: 6 additions & 1 deletion)
@@ -22,6 +22,8 @@ def __init__(self, config: LlamaConfig):

def forward(self, q: Tensor, k: Tensor, offset: tir.Var):
def te_op(x: te.Tensor, offset: tir.Var):
dtype = x.dtype

def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var):
head_dim = tir.const(self.head_dim, "int32")
position_embedding_base = tir.const(self.position_embedding_base, "float32")
@@ -30,11 +32,13 @@ def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var):
(d * 2 % head_dim).astype("float32") / head_dim,
)
freq = (offset + s) / freq
return tir.cos(freq) * x[b, s, h, d] + tir.sin(freq) * tir.if_then_else(
cos = tir.cos(freq).astype(dtype) * x[b, s, h, d]
sin = tir.sin(freq).astype(dtype) * tir.if_then_else(
d < self.head_dim // 2,
-x[b, s, h, d + self.head_dim // 2],
x[b, s, h, d - self.head_dim // 2],
)
return cos + sin

return te.compute(x.shape, compute, name="rotary")

@@ -87,6 +91,7 @@ def forward( # pylint: disable=too-many-locals
d, h_q, h_kv, t = self.head_dim, self.num_q_heads, self.num_kv_heads, total_seq_len
b, s, _ = hidden_states.shape
assert b == 1, "Only support batch size 1 at this moment."

q, k, v = self.qkv_proj(hidden_states)
q = op.reshape(q, (b, s, h_q, d))
k = op.reshape(k, (b, s, h_kv, d))
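The rotary change above computes the angles in float32 and casts `cos`/`sin` back to the activation dtype before the multiply, rather than multiplying in float32 and casting the sum. A NumPy reference of the same rotation, as a sketch for illustration (not the TE kernel itself):

```python
import numpy as np

def rotary_ref(x: np.ndarray, offset: int, base: float = 10000.0) -> np.ndarray:
    """Reference rotary embedding for x of shape (batch, seq, heads, head_dim)."""
    dtype = x.dtype
    _, s, _, head_dim = x.shape
    pos = offset + np.arange(s, dtype=np.float32)[None, :, None, None]
    # Same frequency exponent as the TE compute: (2*d % head_dim) / head_dim.
    exponent = (np.arange(head_dim) * 2 % head_dim) / head_dim
    freq = pos / np.power(base, exponent, dtype=np.float32)  # float32 angles
    half = head_dim // 2
    rotated = np.concatenate([-x[..., half:], x[..., :half]], axis=-1)
    # cos/sin are cast to the activation dtype before the multiply, as in the diff.
    return np.cos(freq).astype(dtype) * x + np.sin(freq).astype(dtype) * rotated
```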
python/mlc_chat/compiler/model/llama_quantization.py (109 changes: 16 additions & 93 deletions)
@@ -1,101 +1,24 @@
"""
Quantization specs for Llama2 architecture.
TODO: add docstring
"""
from typing import Callable, Dict, List, Optional
"""Quantization specs for Llama."""
from typing import Tuple

import tvm
from tvm.runtime import NDArray
from tvm.relax.frontend import nn

from ..parameter import QuantizeMapping
from ..quantization import QuantizeConfig
from ..quantization.group_quantizer import te_quantize as te_group_quantize
from ..quantization import GroupQuantize
from .llama_config import LlamaConfig
from .llama_model import LlamaForCasualLM


def huggingface_group_quantize(
def group_quant(
model_config: LlamaConfig,
quantize_config: QuantizeConfig,
target: Optional[tvm.target.Target] = None,
) -> QuantizeMapping:
"""Returns a parameter mapping that maps a parameter in MLC LLM's model
definition to its eventual names and values after quantization.

Parameters
----------
model_config : LlamaConfig
The configuration of the Llama model.
quantize_config : GroupQuantizeConfig
The configuration of the group quantization.
target : Optional[tvm.target.Target]
The target device to run the quantization on, by default None, which
means the quantization will be run on CPU.

Returns
-------
quantize_map : QuantizeMapping
The parameter mapping from a parameter in MLC LLM's model definition to
its eventual names and values after quantization.
"""

def group_quantize(
param: NDArray, config: QuantizeConfig, target: Optional[tvm.target.Target] = None
):
if target is None or target.kind.name == "llvm":
target = tvm.target.Target("llvm")
device = tvm.cpu()
elif target.kind.name == "cuda":
device = tvm.cuda()
else:
raise ValueError(f"Invalid target device: {target}")
param_tensor = tvm.te.placeholder(param.shape, dtype=param.dtype, name="param")
weight_compute, scale_compute, other_computes = te_group_quantize( # type: ignore
param_tensor, config
)
s = tvm.te.create_schedule( # pylint: disable=invalid-name
[compute.op for compute in [weight_compute, scale_compute] + other_computes]
)
if target.kind.name == "cuda":
# thread_binding for cuda
for compute in [weight_compute, scale_compute] + other_computes:
xo, xi = s[compute].split(compute.op.axis[0], 256) # pylint: disable=invalid-name
s[compute].bind(xo, tvm.te.thread_axis("blockIdx.x"))
s[compute].bind(xi, tvm.te.thread_axis("threadIdx.x"))
f_quantize = tvm.build(
s, [param_tensor, weight_compute, scale_compute], name="group_quantize", target=target
)
weight = tvm.nd.empty(weight_compute.shape, weight_compute.dtype, device=device)
scale = tvm.nd.empty(scale_compute.shape, scale_compute.dtype, device=device)
f_quantize(param.copyto(device), weight, scale)
return weight, scale

# Param check
assert (
quantize_config.kind == "group_quantize"
), f"Invalid quantization config: group quantization expected but got {quantize_config.kind}"
assert (
quantize_config.name == "q4f16_1"
), """Only support q4f16_1 quantization scheme for now."""

# Fetch model parameter & names
model = LlamaForCasualLM(model_config)
_, named_params = model.export_tvm(spec=model.get_default_spec())
parameter_names = {name for name, _ in named_params}

# Init mappings
param_map: Dict[str, List[str]] = {}
map_func: Dict[str, Callable] = {}

# Dispatch quantization scheme
# Also see https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/quantization/__init__.py
for name in parameter_names:
if "norm.weight" not in name and "embed" not in name:
param_map[name] = [f"{name}_quantized", f"{name}_scale"]
map_func[name] = lambda x: group_quantize(x, quantize_config, target=target)
else:
# skip these parameters
param_map[name] = [name]
map_func[name] = lambda x: [x]

return QuantizeMapping(param_map, map_func)
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a Llama2 model using group quantization."""
model: nn.Module = LlamaForCasualLM(model_config)
quant_map = QuantizeMapping({}, {})
model = quantization.quantize_model(
model,
quant_map,
"model",
)
return model, quant_map
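Here `quantization.quantize_model` walks the `nn.Module`, swaps eligible parameters for (quantized weight, scale) pairs, and records the renames in `quant_map`, which is the role the hand-written TE schedule above used to play. A rough NumPy sketch of the underlying group-quantization arithmetic (illustrative only; the exact rounding, zero-point, and bit-packing used by `GroupQuantize` may differ):

```python
import numpy as np

def group_quantize_ref(w: np.ndarray, group_size: int = 32, bits: int = 4):
    """Quantize a weight tensor in groups; assumes w.size is a multiple of group_size."""
    max_q = (1 << bits) - 1                    # 15 for 4-bit storage
    zero = max_q // 2 + 1                      # assumed unsigned zero-point (8)
    groups = w.reshape(-1, group_size).astype(np.float32)
    scale = np.abs(groups).max(axis=1, keepdims=True) / (max_q - zero)  # one scale per group
    scale = np.maximum(scale, 1e-6)            # avoid division by zero for all-zero groups
    q = np.clip(np.round(groups / scale) + zero, 0, max_q).astype(np.uint8)
    return q, scale.astype(np.float16)

def group_dequantize_ref(q: np.ndarray, scale: np.ndarray, bits: int = 4) -> np.ndarray:
    zero = ((1 << bits) - 1) // 2 + 1
    return (q.astype(np.float32) - zero) * scale.astype(np.float32)
```

The removed mapping above exposed each such pair under names like f"{name}_quantized" and f"{name}_scale"; the `QuantizeMapping` returned here carries the equivalent bookkeeping.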
python/mlc_chat/compiler/model/model.py (21 changes: 12 additions & 9 deletions)
@@ -1,12 +1,12 @@
"""A centralized registry of all existing model architures and their configurations."""
import dataclasses
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Tuple

from tvm.relax.frontend import nn

from ..parameter import ExternMapping, QuantizeMapping
from ..quantization.quantization import QuantizeConfig
from . import llama_config, llama_model, llama_parameter
from ..quantization.quantization import Quantization
from . import llama_config, llama_model, llama_parameter, llama_quantization

ModelConfig = Any
"""A ModelConfig is an object that represents a model architecture. It is required to have
@@ -16,8 +16,8 @@ def from_file(cls, path: Path) -> ModelConfig:
...
"""

FuncGetExternMap = Callable[[ModelConfig, QuantizeConfig], ExternMapping]
FuncGetQuantMap = Callable[[ModelConfig, QuantizeConfig], QuantizeMapping]
FuncGetExternMap = Callable[[ModelConfig, Quantization], ExternMapping]
FuncQuantization = Callable[[ModelConfig, Quantization], Tuple[nn.Module, QuantizeMapping]]


@dataclasses.dataclass
@@ -38,15 +38,16 @@ class Model:
source : Dict[str, FuncGetExternMap]
A dictionary that maps the name of a source format to parameter mapping.

quantize: Dict[str, FuncGetQuantMap]
A dictionary that maps the name of a quantization method to quantization mapping.
quantize: Dict[str, FuncQuantization]
A dictionary that maps the name of a quantization method to quantized model and the
quantization parameter mapping.
"""

name: str
config: ModelConfig
model: Callable[[ModelConfig], nn.Module]
source: Dict[str, FuncGetExternMap]
quantize: Dict[str, FuncGetQuantMap]
quantize: Dict[str, FuncQuantization]


MODELS: Dict[str, Model] = {
@@ -58,7 +59,9 @@ class Model:
"huggingface-torch": llama_parameter.huggingface,
"huggingface-safetensor": llama_parameter.huggingface,
},
quantize={},
quantize={
"group-quant": llama_quantization.group_quant,
},
)
}

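With `quantize` now populated per architecture, the registry can be inspected directly; a small sketch, assuming the import path used elsewhere in this PR:

```python
from mlc_chat.compiler import MODELS

# List the quantization methods each registered architecture supports.
for name, model in MODELS.items():
    print(name, "->", sorted(model.quantize.keys()))  # e.g. llama -> ['group-quant']
```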
python/mlc_chat/compiler/quantization/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -1,2 +1,3 @@
"""A subpackage for quantization and dequantization algorithms"""
from .quantization import QUANT, QuantizeConfig
from .group_quantization import GroupQuantize
from .quantization import QUANTIZATION, Quantization