[SLIM][AWQ] AWQ GEMM support (#1362)
Previously, in #1229, we only supported loading INT4-GEMV AWQ weights quantized by [the original repo](https://github.com/mit-han-lab/llm-awq#usage). This PR adds support for loading INT4-GEMM format weights quantized by [AutoAWQ](https://github.com/casper-hansen/AutoAWQ). Here is an example that loads [TheBloke/Llama-2-7B-Chat-AWQ](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ) and runs the benchmark:

```bash
MODEL_PATH=/opt/models/llama-2/llama-2-7b-chat-hf/
OUTPUT_PATH=./dist/new-llama-awq/
QUANT=q4f16_autoawq

python -m mlc_chat gen_config $MODEL_PATH/config.json --quantization $QUANT -o $OUTPUT_PATH --conv-template llama-2

python -m mlc_chat compile $OUTPUT_PATH -o $OUTPUT_PATH/llama.so

python -m mlc_chat convert_weight $MODEL_PATH --quantization $QUANT -o $OUTPUT_PATH --source-format awq --source ../Llama-2-7B-Chat-AWQ/model.safetensors

python -m mlc_chat.cli.benchmark --model $OUTPUT_PATH --model-lib $OUTPUT_PATH/llama.so --device "cuda:0" --prompt "What is the meaning of life?" --generate-length 256
```

Note:
Q: What is the difference between INT4-GEMV and INT4-GEMM?
A: See https://github.com/casper-hansen/AutoAWQ#int4-gemm-vs-int4-gemv-vs-fp16
LeshengJin authored Dec 18, 2023
1 parent 76a77bc commit 26d348a
Showing 4 changed files with 28 additions and 21 deletions.
python/mlc_chat/compiler/model/llama/llama_loader.py (8 additions, 2 deletions)

```diff
@@ -122,7 +122,10 @@ def awq(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping:
                 f"{attn}.v_proj.{quantize_suffix}",
             ],
             functools.partial(
-                lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
+                lambda q, k, v, dtype: np.concatenate(
+                    [q, k, v],
+                    axis=1,  # AWQ GEMM would transpose the weight
+                ).astype(dtype),
                 dtype=mlc_param.dtype,
             ),
         )
@@ -140,7 +143,10 @@ def awq(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping:
                 f"{mlp}.up_proj.{quantize_suffix}",
             ],
             functools.partial(
-                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
+                lambda gate, up, dtype: np.concatenate(
+                    [gate, up],
+                    axis=1,  # AWQ GEMM would transpose the weight
+                ).astype(dtype),
                 dtype=mlc_param.dtype,
             ),
         )
```
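The axis change reflects the GEMM weight layout: AutoAWQ's GEMM format stores packed weights transposed relative to the GEMV layout, as (in_features, out_features // num_elem_per_storage), so fused QKV (and gate/up) tensors are stacked along the output dimension, which is now axis 1. A minimal NumPy sketch of that concatenation, using hypothetical Llama-2-7B-like sizes:

```python
import numpy as np

# Hypothetical sizes: each projection is 4096 -> 4096, int4 packed into uint32
# (8 elements per storage word). In the GEMM layout the packed weight is
# (in_features, out_features // num_elem_per_storage), so the fused QKV weight
# is concatenated along axis 1 (output channels) rather than axis 0.
in_features, proj_out, num_elem_per_storage = 4096, 4096, 8
q = np.zeros((in_features, proj_out // num_elem_per_storage), dtype=np.uint32)
k = np.zeros_like(q)
v = np.zeros_like(q)

qkv = np.concatenate([q, k, v], axis=1)  # stack output channels side by side
print(qkv.shape)  # (4096, 1536)
```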
python/mlc_chat/compiler/quantization/awq_quantization.py (12 additions, 16 deletions)

```diff
@@ -3,7 +3,7 @@
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Optional
 
-from tvm import DataType, DataTypeCode, te, tir
+from tvm import DataType, DataTypeCode, te, tir, topi
 from tvm.relax.frontend import nn
 from tvm.runtime import NDArray
 
@@ -138,16 +138,21 @@ def _dequantize(
             self.num_elem_per_storage,
             self.storage_dtype,
             self.model_dtype,
-            out_shape,
+            [weight.shape[0], weight.shape[1] * self.num_elem_per_storage],
+            ft_reorder=True,
         )
         float_zeros = convert_uint_to_float(
             zeros,
             DataType(self.quantize_dtype).bits,
             self.num_elem_per_storage,
             self.storage_dtype,
             self.model_dtype,
-            out_shape,
+            [zeros.shape[0], zeros.shape[1] * self.num_elem_per_storage],
+            ft_reorder=True,
         )
+        float_weight = topi.transpose(float_weight)
+        float_zeros = topi.transpose(float_zeros)
+        scale = topi.transpose(scale)
         return te.compute(
             shape=[weight.shape[0], weight.shape[1] * self.num_elem_per_storage]
             if out_shape is None
@@ -177,23 +182,14 @@ def __init__(  # pylint: disable=too-many-arguments
         self.out_dtype = out_dtype
         self.config = config
         self.qweight = nn.Parameter(
-            (out_features, tir.ceildiv(in_features, config.num_elem_per_storage)),
-            config.storage_dtype,
+            (in_features, out_features // config.num_elem_per_storage), config.storage_dtype
         )
         self.qzeros = nn.Parameter(
-            (
-                out_features,
-                _calculate_zeros_width(in_features, config.group_size, config.num_elem_per_storage),
-            ),
-            dtype=config.storage_dtype,
+            (in_features // config.group_size, out_features // config.num_elem_per_storage),
+            config.storage_dtype,
         )
         self.scales = nn.Parameter(
-            (
-                out_features,
-                _calculate_zeros_width(in_features, config.group_size, config.num_elem_per_storage)
-                * config.num_elem_per_storage,
-            ),
-            config.model_dtype,
+            (in_features // config.group_size, out_features), config.model_dtype
        )
         if bias:
             self.bias = nn.Parameter(
```
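For reference, here is a small sketch of the packed parameter shapes the constructor now declares for a single linear layer, mirroring the new `nn.Parameter` shapes above; the concrete sizes (a 4096x4096 layer, int4 packed into uint32, group_size=128 as in the q4f16_autoawq config) are illustrative only:

```python
# Illustrative shapes for one 4096x4096 Linear layer in the AutoAWQ GEMM layout,
# assuming int4 packed into uint32 (8 elements per storage word) and group_size=128.
in_features, out_features = 4096, 4096
group_size, num_elem_per_storage = 128, 8

qweight_shape = (in_features, out_features // num_elem_per_storage)               # (4096, 512)
qzeros_shape = (in_features // group_size, out_features // num_elem_per_storage)  # (32, 512)
scales_shape = (in_features // group_size, out_features)                          # (32, 4096)
print(qweight_shape, qzeros_shape, scales_shape)
```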
python/mlc_chat/compiler/quantization/quantization.py (2 additions, 2 deletions)

```diff
@@ -59,8 +59,8 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArray]:
         storage_dtype="uint32",
         model_dtype="float32",
     ),
-    "q4f16_awq": AWQQuantize(
-        name="q4f16_awq",
+    "q4f16_autoawq": AWQQuantize(
+        name="q4f16_autoawq",
         kind="awq",
         group_size=128,
         quantize_dtype="int4",
```
python/mlc_chat/compiler/quantization/utils.py (6 additions, 1 deletion)

```diff
@@ -11,6 +11,7 @@ def convert_uint_to_float(  # pylint: disable=too-many-arguments
     storage_dtype: str,
     model_dtype: str,
     out_shape: Optional[List[tir.PrimExpr]] = None,
+    ft_reorder: Optional[bool] = False,
 ) -> te.Tensor:
     """Convert a quantized uint weight to an unquantized float weight."""
     tir_bin_mask = tir.const((1 << bits) - 1, storage_dtype)
@@ -21,7 +22,11 @@ def convert_uint_to_float(  # pylint: disable=too-many-arguments
         fcompute=lambda i, j: tir.bitwise_and(
             tir.shift_right(
                 weight[i, j // num_elem_per_storage],
-                ((j % num_elem_per_storage) * bits).astype(storage_dtype),
+                (
+                    ((j % num_elem_per_storage) % 2 * 4 + (j % num_elem_per_storage) // 2) * bits
+                    if ft_reorder
+                    else (j % num_elem_per_storage) * bits
+                ).astype(storage_dtype),
             ),
             tir_bin_mask,
         ).astype(model_dtype),
```
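The `ft_reorder` branch accounts for the interleaved nibble order used when eight int4 values are packed into one uint32 word. A small, hypothetical pure-Python helper (not part of the patch) that applies the same shift formula shows which packed slot each logical element is read from:

```python
# Hypothetical helper mirroring the ft_reorder shift in the TIR expression above:
# for element j of a packed word, the shift is (j % 2 * 4 + j // 2) * bits.
def unpack_word(word: int, bits: int = 4, num_elem: int = 8):
    mask = (1 << bits) - 1
    out = []
    for j in range(num_elem):
        shift = (j % 2 * 4 + j // 2) * bits  # reordered shift, as in the diff
        out.append((word >> shift) & mask)
    return out

# Pack nibbles 0..7 in plain order, then read them back with the reordered shifts:
word = sum(v << (4 * i) for i, v in enumerate(range(8)))
print(unpack_word(word))  # [0, 4, 1, 5, 2, 6, 3, 7]
```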
