[SLM] Fix group quantization #1172

Merged 6 commits on Nov 1, 2023
Changes from 3 commits
288 changes: 268 additions & 20 deletions python/mlc_chat/compiler/quantization/group_quantization.py
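For orientation before the diff: the changes below touch the packing/unpacking kernels and the packed parameter shapes for group quantization. At a high level, several low-bit values are packed into one storage word and one scale is kept per group of elements. A minimal NumPy sketch of that round trip, assuming a 4-bit quantize dtype, uint32 storage, group size 32, and symmetric max-abs scaling (these specifics are illustrative assumptions, not copied from the PR):

    import numpy as np

    # Illustrative constants for a q4f16-style config (assumed, not taken verbatim from the PR).
    QUANT_BITS = 4                                        # bits per quantized value
    GROUP_SIZE = 32                                       # values sharing one scale
    NUM_ELEM_PER_STORAGE = 32 // QUANT_BITS               # 8 values per uint32 storage word
    MAX_INT = (1 << (QUANT_BITS - 1)) - 1                 # 7

    def quantize_group(w):
        """Symmetric max-abs group quantization, packed into uint32 words."""
        n, k = w.shape                                    # assume k % GROUP_SIZE == 0
        groups = w.reshape(n, k // GROUP_SIZE, GROUP_SIZE)
        scale = np.maximum(np.abs(groups).max(axis=-1, keepdims=True) / MAX_INT, 1e-8)
        q = np.clip(np.round(groups / scale) + MAX_INT, 0, 2 * MAX_INT)  # unsigned 0..2*MAX_INT
        q = q.astype(np.uint32).reshape(n, k)
        packed = np.zeros((n, k // NUM_ELEM_PER_STORAGE), dtype=np.uint32)
        for j in range(k):
            shift = (j % NUM_ELEM_PER_STORAGE) * QUANT_BITS
            packed[:, j // NUM_ELEM_PER_STORAGE] |= (q[:, j] << shift).astype(np.uint32)
        return packed, scale.reshape(n, k // GROUP_SIZE)

    def dequantize_group(packed, scale):
        """Inverse of quantize_group: shift, mask, re-center, and rescale."""
        n, num_storage = packed.shape
        k = num_storage * NUM_ELEM_PER_STORAGE
        mask = (1 << QUANT_BITS) - 1                      # mask uses the quantize bits, not the storage bits
        out = np.empty((n, k), dtype=np.float32)
        for j in range(k):
            shift = (j % NUM_ELEM_PER_STORAGE) * QUANT_BITS
            q = (packed[:, j // NUM_ELEM_PER_STORAGE] >> shift) & mask
            out[:, j] = (q.astype(np.float32) - MAX_INT) * scale[:, j // GROUP_SIZE]
        return out

    w = np.random.randn(8, 64).astype(np.float32)
    packed, scale = quantize_group(w)
    w_hat = dequantize_group(packed, scale)
    assert np.abs(w - w_hat).max() <= scale.max() / 2 + 1e-6  # error within half a quantization step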
@@ -2,6 +2,7 @@
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

import numpy as np
from tvm import DataType, DataTypeCode, device
from tvm import dlight as dl
from tvm import relax, te, tir
@@ -50,7 +51,25 @@ def quantize_model(
quant_map: QuantizeMapping,
name_prefix: str,
) -> nn.Module:
"""Quantize model with group quantization"""
"""
Quantize model with group quantization

Parameters
----------
model : nn.Module
The non-quantized nn.Module.

quant_map : QuantizeMapping
The quantize mapping with name mapping and func mapping.

name_prefix : str
The name prefix for visited weight.

Returns
-------
ret : nn.Module
The quantized nn.Module.
"""

class _Mutator(nn.Mutator):
def __init__(self, config: GroupQuantize, quant_map: QuantizeMapping) -> None:
@@ -59,11 +78,37 @@ def __init__(self, config: GroupQuantize, quant_map: QuantizeMapping) -> None:
self.quant_map = quant_map

def visit_module(self, name: str, node: nn.Module) -> Any:
"""
The visiting method for group quantization of nn.Module nodes.

Parameters
----------
name : str
The name of the current node.

node : nn.Module
The current node of nn.Module to mutate.

Returns
-------
ret_node: Any
The new node to replace current node.
"""
if isinstance(node, nn.Linear):
weight_name = f"{name}.weight"
self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"]
# self.quant_map.map_func[weight_name] = self.config.quantize
self.quant_map.map_func[weight_name] = self.config.quantize_weight
return GroupQuantizeLinear.from_linear(node, self.config)
if isinstance(node, nn.MultiLinear):
weight_name = f"{name}.weight"
self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"]
self.quant_map.map_func[weight_name] = self.config.quantize_weight
return GroupQuantizeMultiLinear.from_multilinear(node, self.config)
if isinstance(node, nn.Embedding):
weight_name = f"{name}.weight"
self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"]
self.quant_map.map_func[weight_name] = self.config.quantize_weight
return GroupQuantizeEmbedding.from_embedding(node, self.config)
return self.visit(name, node)

model.to(dtype=self.model_dtype)
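For orientation, a hedged sketch of the bookkeeping the mutator above performs for each quantized layer: the original weight name is mapped to the names of its packed weight and scale, and to the function that produces them. The layer name and the stand-in quantize function below are hypothetical, not taken from this PR:

    # Hedged illustration of the quant_map bookkeeping; names are hypothetical.
    param_map = {}
    map_func = {}

    def register_linear(name, quantize_weight):
        weight_name = f"{name}.weight"
        param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"]  # packed weight + scale
        map_func[weight_name] = quantize_weight                           # how to produce them

    register_linear("model.layers.0.self_attn.o_proj", quantize_weight=lambda w: w)
    print(param_map["model.layers.0.self_attn.o_proj.weight"])
    # ['model.layers.0.self_attn.o_proj.q_weight', 'model.layers.0.self_attn.o_proj.q_scale']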
@@ -77,9 +122,7 @@ def _dequantize(
scale: te.Tensor,
out_shape: Optional[List[tir.PrimExpr]] = None,
):
quantize_dtype = DataType(self.quantize_dtype)
storage_dtype = DataType(self.storage_dtype)
tir_bin_mask = tir.const((2**quantize_dtype.bits) - 1, self.storage_dtype)
tir_bin_mask = tir.const((1 << DataType(self.quantize_dtype).bits) - 1, self.storage_dtype)
tir_max_int = tir.const(self.max_int_value, self.model_dtype)
dequantized_weight = te.compute(
shape=[weight.shape[0], weight.shape[1] * self.num_elem_per_storage]
@@ -90,7 +133,11 @@
tir.bitwise_and(
tir.shift_right(
weight[i, j // self.num_elem_per_storage],
(j % self.num_elem_per_storage) * storage_dtype.bits,
tir.Cast(
self.storage_dtype,
(j % self.num_elem_per_storage)
* DataType(self.quantize_dtype).bits,
),
),
tir_bin_mask,
),
@@ -102,15 +149,27 @@ def quantize_weight(self, weight: NDArray) -> List[NDArray]:
return dequantized_weight

def quantize_weight(self, weight: NDArray) -> List[NDArray]:
"""Quantize weight with group quantization"""
"""
Quantize weight with group quantization

Parameters
----------
weight : NDArray
The original weight.

Returns
-------
ret: List[NDArray]
The list of group quantized weights.
"""
assert weight.dtype == self.model_dtype
assert len(weight.shape) == 2
bb = relax.BlockBuilder()
bb = relax.BlockBuilder() # pylint: disable=invalid-name
weight_var = relax.Var("weight", relax.TensorStructInfo(weight.shape, self.model_dtype))
with bb.function(name="quantize", params=[weight_var]):
with bb.dataflow():
lv = bb.emit_te(self._quantize, weight_var)
gv = bb.emit_output(lv)
lv = bb.emit_te(self._quantize, weight_var) # pylint: disable=invalid-name
gv = bb.emit_output(lv) # pylint: disable=invalid-name
bb.emit_func_output(gv)
mod = bb.get()
with Target("cuda"):
@@ -119,7 +178,7 @@ def quantize_weight(self, weight: NDArray) -> List[NDArray]:
)(mod)
ex = relax.build(mod, "cuda")
dev = device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)
vm = relax.VirtualMachine(ex, dev) # pylint: disable=invalid-name
return vm["quantize"](weight)

def _quantize( # pylint: disable=too-many-locals
@@ -164,7 +223,7 @@ def _quantize(  # pylint: disable=too-many-locals
)

# compute quantized weight per storage
r = te.reduce_axis((0, self.num_elem_per_storage), name="r")
r = te.reduce_axis((0, self.num_elem_per_storage), name="r") # pylint: disable=invalid-name
num_storage = self.num_storage_per_group * num_group
quantized_weight_shape = (n, num_storage)
quantized_weight = te.compute(
@@ -195,20 +254,36 @@ def __init__(  # pylint: disable=too-many-arguments
self.out_features = out_features
self.out_dtype = out_dtype
self.config = config
n_group = tir.ceildiv(in_features, config.group_size)
self.weight = nn.Parameter(
(out_features, n_group * config.num_elem_per_storage),
(out_features, tir.ceildiv(in_features, config.num_elem_per_storage)),
config.storage_dtype,
)
self.scale = nn.Parameter((out_features, n_group), config.model_dtype)
self.scale = nn.Parameter(
(out_features, tir.ceildiv(in_features, config.group_size)), config.model_dtype
)
if bias:
self.bias = nn.Parameter((out_features,), config.model_dtype)
else:
self.bias = None
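A quick arithmetic check of the corrected parameter shapes above, assuming a 4-bit quantize dtype in uint32 storage (8 values per word) and group size 32; the feature sizes and config values are illustrative, not taken from the PR:

    from math import ceil

    in_features, out_features = 4096, 4096     # illustrative sizes
    group_size = 32
    num_elem_per_storage = 8                   # 32-bit storage word / 4-bit values

    n_group = ceil(in_features / group_size)                    # 128 scales per output row
    cols_old = n_group * num_elem_per_storage                   # 1024 columns (previous weight shape)
    cols_new = ceil(in_features / num_elem_per_storage)         # 512 columns (ceildiv in this PR)

    # 4096 inputs at 8 values per uint32 word need exactly 512 words per output row.
    assert cols_new * num_elem_per_storage == in_features
    print((out_features, cols_old), (out_features, cols_new))   # (4096, 1024) (4096, 512)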

@staticmethod
def from_linear(linear: nn.Linear, config: GroupQuantize):
"""Converts a non-quantized nn.Linear to a quantized GroupQuantizeLinear"""
def from_linear(linear: nn.Linear, config: GroupQuantize) -> "GroupQuantizeLinear":
"""
Converts a non-quantized nn.Linear to a group quantized GroupQuantizeLinear

Parameters
----------
linear : nn.Linear
The non-quantized nn.Linear.

config : GroupQuantize
The group quantization config.

Returns
-------
ret : GroupQuantizeLinear
The group quantized GroupQuantizeLinear layer.
"""
return GroupQuantizeLinear(
in_features=linear.in_features,
out_features=linear.out_features,
@@ -217,21 +292,194 @@ def from_linear(linear: nn.Linear, config: GroupQuantize):
out_dtype=linear.out_dtype,
)

def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name,missing-docstring
def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name
"""
Forward method for group quantized linear layer.

Parameters
----------
x : nn.Tensor
The input tensor.

Returns
-------
ret : nn.Tensor
The output tensor for the group quantized linear layer.
"""
w = nn.op.tensor_expr_op( # pylint: disable=invalid-name
lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access
weight,
scale,
[tir.IntImm("int64", self.out_features), tir.IntImm("int64", self.in_features)],
),
name_hint="decode",
args=[self.weight, self.scale],
)
w = nn.op.permute_dims(w) # pylint: disable=invalid-name
x = nn.op.matmul(x, w, out_dtype=self.out_dtype)
if self.bias is not None:
x = x + self.bias
return x


class GroupQuantizeMultiLinear(nn.Module): # pylint: disable=too-many-instance-attributes
"""An nn.MultiLinear module with group quantization"""

def __init__( # pylint: disable=too-many-arguments
self,
in_features: int,
out_features: nn.Sequence[int],
config: GroupQuantize,
bias: bool = True,
out_dtype: Optional[str] = None,
):
assert len(out_features) > 0
self.total_out_features = sum(out_features)

super().__init__()
self.in_features = in_features
self.out_features = out_features
self.out_dtype = out_dtype
self.config = config
self.weight = nn.Parameter(
(self.total_out_features, tir.ceildiv(in_features, config.num_elem_per_storage)),
config.storage_dtype,
)
self.scale = nn.Parameter(
(self.total_out_features, tir.ceildiv(in_features, config.group_size)),
config.model_dtype,
)
if bias:
self.bias = nn.Parameter((self.total_out_features,), config.model_dtype)
else:
self.bias = None

@staticmethod
def from_multilinear(
multi_linear: nn.MultiLinear, config: GroupQuantize
) -> "GroupQuantizeMultiLinear":
"""
Converts a non-quantized nn.MultiLinear to a group quantized GroupQuantizeMultiLinear

Parameters
----------
multi_linear : nn.MultiLinear
The non-quantized nn.MultiLinear.

config : GroupQuantize
The group quantization config.

Returns
-------
ret : GroupQuantizeMultiLinear
The group quantized GroupQuantizeMultiLinear layer.
"""
return GroupQuantizeMultiLinear(
in_features=multi_linear.in_features,
out_features=multi_linear.out_features,
config=config,
bias=getattr(multi_linear, "bias", None) is not None,
out_dtype=multi_linear.out_dtype,
)

def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name
"""
Forward method for the group quantized multi-linear layer.

Parameters
----------
x : nn.Tensor
The input tensor.

Returns
-------
ret : nn.Tensor
The output tensor for the group quantized multi-linear layer.
"""
sections = list(np.cumsum(self.out_features)[:-1])
w = nn.op.tensor_expr_op( # pylint: disable=invalid-name
lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access
weight,
scale,
[
tir.IntImm("int64", self.out_features),
tir.IntImm("int64", self.total_out_features),
tir.IntImm("int64", self.in_features),
],
),
name_hint="decode",
args=[self.weight, self.scale],
)
# x: [*B, in_features]
# w: [in_features, out_features]
w = nn.op.permute_dims(w) # pylint: disable=invalid-name
# x: [*B, out_features]
x = nn.op.matmul(x, w, out_dtype=self.out_dtype)
if self.bias is not None:
x = x + self.bias
return x
results = nn.op.split(x, sections, axis=-1)
return results
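The fused output of the multi-linear matmul is split back into its per-projection pieces at the cumulative-sum boundaries of out_features; a small NumPy sketch of that step (the projection widths are illustrative):

    import numpy as np

    out_features = [4096, 1024, 1024]              # e.g. fused q/k/v widths (illustrative)
    total_out_features = sum(out_features)
    x = np.random.randn(2, total_out_features)     # stands in for the fused matmul output

    sections = list(np.cumsum(out_features)[:-1])  # [4096, 5120] -> split points
    q, k, v = np.split(x, sections, axis=-1)
    assert [t.shape[-1] for t in (q, k, v)] == out_features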


class GroupQuantizeEmbedding(nn.Module):
"""An nn.Embedding module with group quantization"""

def __init__(self, num: int, dim: int, config: GroupQuantize):
self.num = num
self.dim = dim
self.config = config
n_group = tir.ceildiv(dim, config.group_size)
self.weight = nn.Parameter(
(num, n_group * config.num_elem_per_storage), config.storage_dtype
)
self.scale = nn.Parameter((num, n_group), config.model_dtype)

@staticmethod
def from_embedding(embedding: nn.Embedding, config: GroupQuantize) -> "GroupQuantizeEmbedding":
"""
Converts a non-quantized nn.Embedding to a group quantized GroupQuantizeEmbedding

Parameters
----------
embedding : nn.Embedding
The non-quantized nn.Embedding.

config : GroupQuantize
The group quantization config.

Returns
-------
ret : GroupQuantizeEmbedding
The group quantized GroupQuantizeEmbedding layer.
"""
num, dim = embedding.weight.shape
return GroupQuantizeEmbedding(num, dim, config)

def forward(self, x: nn.Tensor): # pylint: disable=invalid-name
"""
Forward method for group quantized embedding layer.

Parameters
----------
x : nn.Tensor
The input tensor.

Returns
-------
ret : nn.Tensor
The output tensor for the embedding layer.
"""
w = nn.op.tensor_expr_op( # pylint: disable=invalid-name
lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access
weight,
scale,
[tir.IntImm("int64", self.num), tir.IntImm("int64", self.dim)],
),
name_hint="decode",
args=[self.weight, self.scale],
)
if x.ndim == 1:
return nn.op.take(w, x, axis=0)
return nn.op.reshape(
nn.op.take(w, nn.op.reshape(x, shape=[-1]), axis=0),
shape=[*x.shape, self.dim],
)
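The embedding forward above gathers rows from the dequantized table and, for multi-dimensional index tensors, flattens the indices before the take and reshapes afterwards; a small NumPy sketch of that lookup (the vocabulary and hidden sizes are illustrative):

    import numpy as np

    num, dim = 32000, 4096                                  # illustrative vocab / hidden size
    table = np.random.randn(num, dim).astype(np.float32)    # stands in for the dequantized weight

    ids = np.array([[1, 5, 42], [7, 7, 0]])                 # [batch, seq] token ids
    emb = np.take(table, ids.reshape(-1), axis=0).reshape(*ids.shape, dim)
    assert emb.shape == (2, 3, dim)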