Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Int4-AWQ] Torch Int-4 AWQ Dequantization and Configuration Options #146

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions tests/kernels/test_awq_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import torch

from vllm.model_executor.layers.quantization.awq import torch_awq_dequantize
from vllm.model_executor.layers.quantization.awq_triton import (
awq_dequantize_triton, awq_gemm_triton)

Expand Down Expand Up @@ -76,7 +77,7 @@ def awq_gemm_torch(input: torch.Tensor, qweight: torch.Tensor,
print(f"awq_gemm_torch:input_rows = {input_rows} input_cols = {input_cols}"
f" qweight_rows = {qweight_rows} qweight_cols = {qweight_cols}"
f" scales_rows = {scales_rows} scales_cols = {scales_cols}")
weights, zeros = awq_dequantize_torch(qweight, scales, qzeros)
weights = torch_awq_dequantize(qweight, scales, qzeros)
return torch.matmul(input, weights)


Expand Down Expand Up @@ -123,7 +124,7 @@ def test_dequantize(qweight_rows, qweight_cols):
print("Any infs in triton result? -->"
f"{torch.any(torch.isinf(iweights_triton))}")

iweights_torch, _ = awq_dequantize_torch(qweight, scales, zeros)
iweights_torch = torch_awq_dequantize(qweight, scales, zeros)
print(f"Torch result:iweights_torch = {iweights_torch}")

diff = iweights_torch - iweights_triton
Expand Down
24 changes: 12 additions & 12 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
print(f"awq_dequantize:qweight.shape = {qweight.shape}"
f"scales = {scales.shape},"
f"zeros = {zeros.shape},"
f"split_k_iters = {split_k_iters},"
f"thx = {thx}"
f"thy = {thy}")
# print(f"awq_dequantize:qweight.shape = {qweight.shape}"
# f"scales = {scales.shape},"
# f"zeros = {zeros.shape},"
# f"split_k_iters = {split_k_iters},"
# f"thx = {thx}"
# f"thy = {thy}")
if is_hip() and envs.VLLM_USE_TRITON_AWQ:
from vllm.model_executor.layers.quantization.awq_triton import (
awq_dequantize_triton)
Expand All @@ -153,12 +153,12 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,

def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
if input.shape[0] > 1:
print(f"awq_gemm:input.shape = {input.shape},"
f"qweight = {qweight.shape},"
f"qzeros = {qzeros.shape},"
f"scales.shape = {scales.shape},"
f"split_k_iters = {split_k_iters}")
# if input.shape[0] > 1:
# print(f"awq_gemm:input.shape = {input.shape},"
# f"qweight = {qweight.shape},"
# f"qzeros = {qzeros.shape},"
# f"scales.shape = {scales.shape},"
# f"split_k_iters = {split_k_iters}")
if is_hip() and envs.VLLM_USE_TRITON_AWQ:
from vllm.model_executor.layers.quantization.awq_triton import (
awq_gemm_triton)
Expand Down
4 changes: 2 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def __init__(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")

self.use_naive_attn = False
self.use_naive_attn = envs.VLLM_USE_SDPA_ATTENTION # Default False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
Expand All @@ -306,7 +306,7 @@ def __init__(

if self.use_naive_attn:
self.attn_func = _naive_attention
logger.debug("Using naive attention in ROCmBackend")
logger.debug("Using naive (SDPA) attention in ROCmBackend")

def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
Expand Down
18 changes: 18 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
VLLM_INSTANCE_ID: Optional[str] = None
VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
VLLM_ROCM_PREFER_TORCH: bool = False
VLLM_ROCM_PREFER_TRITON: bool = True
VLLM_USE_SDPA_ATTENTION: bool = False
VLLM_USE_TRITON_FLASH_ATTN: bool = True
VLLM_USE_ROCM_SKINNY_GEMM: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
Expand Down Expand Up @@ -136,6 +139,21 @@
"LD_LIBRARY_PATH":
lambda: os.environ.get("LD_LIBRARY_PATH", None),

# flag to tell vllm to prefer torch on ROCm
"VLLM_ROCM_PREFER_TORCH":
lambda: (os.environ.get("VLLM_ROCM_PREFER_TORCH", "False").lower() in
("true", "1")),

# flag to tell vllm to prefer triton on ROCm
"VLLM_ROCM_PREFER_TRITON":
lambda: (os.environ.get("VLLM_ROCM_PREFER_TRITON", "True").lower() in
("true", "1")),

# flag to control if vllm should use naive scaled dot-product attention
"VLLM_USE_SDPA_ATTENTION":
lambda: (os.environ.get("VLLM_USE_SDPA_ATTENTION", "False").lower() in
("true", "1")),

# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
Expand Down
53 changes: 51 additions & 2 deletions vllm/model_executor/layers/quantization/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from torch.nn.parameter import Parameter

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
Expand Down Expand Up @@ -164,12 +165,60 @@ def apply(self,
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

if FP16_MATMUL_HEURISTIC_CONDITION:
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
prefer_torch = envs.VLLM_ROCM_PREFER_TORCH
prefer_triton = envs.VLLM_ROCM_PREFER_TRITON

if (FP16_MATMUL_HEURISTIC_CONDITION
or (prefer_torch and not prefer_triton)):
if prefer_triton:
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
else:
out = torch_awq_dequantize(qweight, scales, qzeros)
out = torch.matmul(reshaped_x, out)
else:
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor)
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)


def torch_awq_dequantize(qweights: torch.Tensor, scales: torch.Tensor,
qzeros: torch.Tensor) -> torch.Tensor:
reverse_awq_func_desc = torch.tensor([0, 16, 4, 20, 8, 24, 12, 28],
dtype=torch.int32,
device=qweights.device)
if qzeros is None:
qzeros = torch.zeros_like(qweights)

while qweights.dim() < 2:
qweights = torch.unsqueeze(qweights, 0)
while qzeros.dim() < 2:
qzeros = torch.unsqueeze(qzeros, 0)
while scales.dim() < 2:
scales = torch.unsqueeze(scales, 0)

rows = qweights.size(-2)
group_size_zeros = rows // qzeros.size(-2)
group_size_scales = rows // scales.size(-2)

qweights_shape = list(qweights.shape)
qweights_shape[-1] *= 8
qzeros_shape = list(qzeros.shape)
qzeros_shape[-1] *= 8

qweights = torch.unsqueeze(qweights, -1)
qzeros = torch.unsqueeze(qzeros, -1)

unpacked_weights = torch.bitwise_right_shift(qweights,
reverse_awq_func_desc)
unpacked_weights = torch.bitwise_and(unpacked_weights, 0xf)
unpacked_weights = unpacked_weights.to(torch.int8).view(qweights_shape)

unpacked_zeros = torch.bitwise_right_shift(qzeros, reverse_awq_func_desc)
unpacked_zeros = torch.bitwise_and(unpacked_zeros, 0xf)
unpacked_zeros = unpacked_zeros.to(torch.int8).view(qzeros_shape)
unpacked_zeros = unpacked_zeros.repeat_interleave(group_size_zeros, dim=-2)

functional_scales = scales.repeat_interleave(group_size_scales, dim=-2)
return (unpacked_weights - unpacked_zeros) * functional_scales
Loading