Quantized embedding (pytorch#536)
* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py
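For downstream code the practical effect of this refactor is an import change; a minimal sketch, assuming the torchchat module layout shown in the quantize.py diff below:

# quantize.py previously defined the embedding wrapper locally as QuantizedGroupEmbedding;
# after this commit both quantized ops come from qops.py, with the linear op aliased as before.
from qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding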
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 5e266fb commit d69915a
Showing 3 changed files with 128 additions and 114 deletions.
1 change: 0 additions & 1 deletion eval.py
@@ -28,7 +28,6 @@
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.triton.cudagraphs = True
torch._dynamo.config.cache_size_limit = 100000
import time

try:
import lm_eval
125 changes: 125 additions & 0 deletions qops.py
@@ -3,6 +3,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from build.utils import (
find_multiple,
get_device_str,
get_precision,
name_to_dtype,
state_dict_device,
use_et_backend,
)
from torch.nn.parameter import Parameter


@@ -84,3 +93,119 @@ def __init__(

def forward(self, input: torch.Tensor) -> torch.Tensor:
return linear_int8(input, self.weight, self.scales)


class QuantizedEmbedding(torch.nn.Module):
def __init__(
self,
num_embeddings: int, # vocab_size: int,
embedding_dim: int,
device=None,
dtype=None,
*,
bitwidth: int,
groupsize: Optional[int] = None,
) -> None:
super().__init__()
if dtype is None:
dtype = torch.half

if groupsize is None or groupsize == 0:
groupsize = embedding_dim
self.groupsize = groupsize
self.dtype = dtype
self.bitwidth = bitwidth

if use_et_backend():
self.forward = self.et_forward
else:
self.forward = self.aoti_forward

if bitwidth == 8:
self.register_buffer(
"weight",
torch.empty(
(num_embeddings, embedding_dim), dtype=torch.int8, device=device
),
)
elif bitwidth == 4: # packed
self.register_buffer(
"weight",
torch.empty(
(num_embeddings, embedding_dim // 2),
dtype=torch.uint8,
device=device,
),
)
else:
raise RuntimeError(
f"QUantized embedding does not support bitwidth={bitwidth}"
)

groups_per_row = (embedding_dim + groupsize - 1) // groupsize
if groups_per_row > 1:
self.register_buffer(
"scales",
torch.ones(
(num_embeddings, groups_per_row), dtype=torch.float16, device=device
),
)
else:
self.register_buffer(
"scales",
torch.ones((num_embeddings,), dtype=torch.float16, device=device),
)

@torch.no_grad()
def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
if self.bitwidth == 8:
return torch.ops.quantized_decomposed.embedding_byte.dtype(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)
else:
return torch.ops.quantized_decomposed.embedding_4bit.dtype(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)

@torch.no_grad()
def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
# result_weights = self.weight.index_select(0, indices.view(-1))
# result_scales = self.scales.index_select(0, indices.view(-1))

if self.bitwidth == 4:
weight_even = self.weight.div(16, rounding_mode="trunc")
weight_odd = self.weight.remainder(16)
weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
weight = weight_unpacked.view(self.weight.shape[0], -1)
weight = weight.to(torch.int8).add(-8)
else:
weight = self.weight

scales = self.scales.view(weight.shape[0], -1)

result_weights = F.embedding(indices, weight)
result_scales = F.embedding(indices, scales)

rw_view = result_weights.to(dtype=result_scales.dtype).view(
tuple(
result_weights.shape[:-1]
+ (
scales.shape[1],
-1,
)
)
)
rs_view = result_scales.view(
tuple(result_scales.shape[:-1])
+ (
scales.shape[1],
1,
)
)
# print(f"rw_view {rw_view.shape}")
# print(f"rs_view {rs_view.shape}")

r = rw_view * rs_view
return r.view(indices.size() + (-1,))

# r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))
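To illustrate how the relocated class is used, here is a minimal, hypothetical round-trip for the 8-bit path. The weight and scale values are invented stand-ins (in the repo they come from the quantization flow in quantize.py), and the sketch assumes the torchchat source tree is on the path so that qops and build.utils import, and that the default AOTI/eager forward (not the ExecuTorch backend) is selected:

import torch
from qops import QuantizedEmbedding

vocab_size, dim = 32, 8
# groupsize defaults to embedding_dim, so the scales buffer has shape (vocab_size,)
emb = QuantizedEmbedding(vocab_size, dim, bitwidth=8)

# Stand-in quantized weights and per-row scales (normally produced by quantize.py).
int8_weight = torch.randint(-128, 128, (vocab_size, dim), dtype=torch.int8)
scales = torch.rand(vocab_size, dtype=torch.float16) * 0.1
emb.weight.copy_(int8_weight)
emb.scales.copy_(scales)

indices = torch.tensor([[0, 3], [7, 1]])
out = emb(indices)  # shape (2, 2, 8), dtype float16

# The AOTI path dequantizes as weight * scale, broadcast per row.
ref = int8_weight[indices].to(torch.float16) * scales[indices].unsqueeze(-1)
assert torch.equal(out, ref)

With a nonzero groupsize smaller than embedding_dim, the scales buffer instead has shape (num_embeddings, ceil(embedding_dim / groupsize)) and each group of columns is rescaled separately.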
116 changes: 3 additions & 113 deletions quantize.py
@@ -23,8 +23,8 @@
state_dict_device,
use_et_backend,
)
from qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding

from qops import LinearInt8 as WeightOnlyInt8Linear

#########################################################################
### torchchat quantization API ###
@@ -489,9 +489,9 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
setattr(
module,
name,
QuantizedGroupEmbedding(
QuantizedEmbedding(
device=device,
vocab_size=child.weight.shape[0],
num_embeddings=child.weight.shape[0],
embedding_dim=child.weight.shape[1],
bitwidth=bitwidth,
groupsize=groupsize,
@@ -586,116 +586,6 @@ def quantized_model(self) -> nn.Module:
return self.model_


class QuantizedGroupEmbedding(torch.nn.Module):
def __init__(
self,
device,
vocab_size: int,
embedding_dim: int,
bitwidth: int,
groupsize: Optional[int] = None,
*,
dtype=torch.half,
) -> None:
super().__init__()
if groupsize is None or groupsize == 0:
groupsize = embedding_dim
self.groupsize = groupsize
self.dtype = dtype
self.bitwidth = bitwidth

if use_et_backend():
self.forward = self.et_forward
else:
self.forward = self.aoti_forward

if bitwidth == 8:
self.register_buffer(
"weight",
torch.empty(
(vocab_size, embedding_dim), dtype=torch.int8, device=device
),
)
elif bitwidth == 4: # packed
self.register_buffer(
"weight",
torch.empty(
(vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
),
)
else:
raise RuntimeError(
f"QUantized embedding does not support bitwidth={bitwidth}"
)

groups_per_row = (embedding_dim + groupsize - 1) // groupsize
if groups_per_row > 1:
self.register_buffer(
"scales",
torch.ones(
(vocab_size, groups_per_row), dtype=torch.float16, device=device
),
)
else:
self.register_buffer(
"scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
)

@torch.no_grad()
def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
if self.bitwidth == 8:
return torch.ops.quantized_decomposed.embedding_byte.dtype(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)
else:
return torch.ops.quantized_decomposed.embedding_4bit.dtype(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)

@torch.no_grad()
def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
# result_weights = self.weight.index_select(0, indices.view(-1))
# result_scales = self.scales.index_select(0, indices.view(-1))

if self.bitwidth == 4:
weight_even = self.weight.div(16, rounding_mode="trunc")
weight_odd = self.weight.remainder(16)
weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
weight = weight_unpacked.view(self.weight.shape[0], -1)
weight = weight.to(torch.int8).add(-8)
else:
weight = self.weight

scales = self.scales.view(weight.shape[0], -1)

result_weights = F.embedding(indices, weight)
result_scales = F.embedding(indices, scales)

rw_view = result_weights.to(dtype=result_scales.dtype).view(
tuple(
result_weights.shape[:-1]
+ (
scales.shape[1],
-1,
)
)
)
rs_view = result_scales.view(
tuple(result_scales.shape[:-1])
+ (
scales.shape[1],
1,
)
)
# print(f"rw_view {rw_view.shape}")
# print(f"rs_view {rs_view.shape}")

r = rw_view * rs_view
return r.view(indices.size() + (-1,))

# r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))


#########################################################################
##### weight only int4 per channel groupwise quantized code ######

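The 4-bit path stores two quantized values per uint8 byte; aoti_forward unpacks them with div/remainder and re-centers to the signed range [-8, 7]. Below is a rough sketch of the packing convention this implies, with hypothetical helper names that are not part of the repo (the actual packer used by the quantization flow may differ):

import torch

def pack_int4(q: torch.Tensor) -> torch.Tensor:
    # q: int8 values in [-8, 7] with an even-sized last dimension.
    u = (q + 8).to(torch.uint8)          # shift to the unsigned range [0, 15]
    hi, lo = u[..., 0::2], u[..., 1::2]  # even columns -> high nibble, odd columns -> low nibble
    return hi * 16 + lo                  # one byte per pair of 4-bit values

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    # Mirrors the unpacking in QuantizedEmbedding.aoti_forward.
    hi = packed.div(16, rounding_mode="trunc")
    lo = packed.remainder(16)
    q = torch.stack((hi, lo), dim=-1).view(*packed.shape[:-1], -1)
    return q.to(torch.int8) - 8

q = torch.randint(-8, 8, (4, 6), dtype=torch.int8)
assert torch.equal(unpack_int4(pack_int4(q)), q)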