Commit: dtype updates
Michael Gschwind committed Apr 10, 2024
1 parent d6c2bc1 commit d5015b9
Showing 1 changed file with 11 additions and 9 deletions: quantize.py
@@ -174,7 +174,7 @@ def dynamically_quantize_per_channel(
 
 
 
-def get_group_qparams(w, n_bit=4, groupsize=128):
+def get_group_qparams(w, n_bit=4, groupsize=128, *, scales_dtype=torch.float):
     # needed for GPTQ with padding
     if groupsize > w.shape[-1]:
         groupsize = w.shape[-1]
@@ -190,15 +190,15 @@ def get_group_qparams(w, n_bit=4, groupsize=128):
     max_int = 2**n_bit - 1
     scales = (max_val - min_val).clamp(min=1e-6) / max_int
     zeros = min_val + scales * (2 ** (n_bit - 1))
-    return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
-        torch.bfloat16
+    return scales.to(scales_dtype).reshape(w.shape[0], -1), zeros.to(
+        scales_dtype
     ).reshape(w.shape[0], -1)
 
 
-def pack_scales_and_zeros(scales, zeros):
+def pack_scales_and_zeros(scales, zeros, *, scales_dtype=torch.float):
     assert scales.shape == zeros.shape
-    assert scales.dtype == torch.bfloat16
-    assert zeros.dtype == torch.bfloat16
+    assert scales.dtype == scales_dtype
+    assert zeros.dtype == scales_dtype
     return (
         torch.cat(
             [
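For reference, a minimal usage sketch of the updated signatures (the import path and tensor sizes here are assumptions, not from the commit): scales and zeros now default to torch.float, and the old bfloat16 behavior stays reachable through the new keyword-only scales_dtype argument.

import torch
from quantize import get_group_qparams, pack_scales_and_zeros  # assumed import path

w = torch.randn(32, 256)

# Defaults: qparams come back as float32 and pack without tripping the asserts.
scales, zeros = get_group_qparams(w, n_bit=4, groupsize=128)
scales_and_zeros = pack_scales_and_zeros(scales, zeros)

# The previous bfloat16 behavior is still available explicitly.
scales_bf16, zeros_bf16 = get_group_qparams(w, scales_dtype=torch.bfloat16)
packed_bf16 = pack_scales_and_zeros(scales_bf16, zeros_bf16, scales_dtype=torch.bfloat16)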
@@ -658,7 +658,7 @@ def create_quantized_state_dict(self):
                       "and that groupsize and inner_k_tiles*16 evenly divide into it")
                 continue
             weight_int4pack, scales_and_zeros = _int4_prepare_int4_weight_and_scales_and_zeros(
-                weight.to(torch.bfloat16), self.groupsize, self.inner_k_tiles
+                weight.to(torch.float), self.groupsize, self.inner_k_tiles
             )
             cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
             cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')
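To see what this hunk changes, here is a hedged, self-contained sketch (not the repository's code; the helper name and sizes are illustrative) of a per-group 4-bit quantize/dequantize round trip with the qparams held in torch.float, the dtype the weights are now converted to before packing:

import torch

def group_quant_round_trip(w, n_bit=4, groupsize=128, scales_dtype=torch.float):
    # Group along the last dimension, as get_group_qparams does.
    to_quant = w.reshape(-1, groupsize).to(scales_dtype)
    min_val = to_quant.amin(dim=1, keepdim=True)
    max_val = to_quant.amax(dim=1, keepdim=True)
    max_int = 2**n_bit - 1
    scales = (max_val - min_val).clamp(min=1e-6) / max_int
    q = ((to_quant - min_val) / scales).round().clamp(0, max_int)
    return (q * scales + min_val).reshape(w.shape)

w = torch.randn(4096, 4096)
err = (w - group_quant_round_trip(w)).abs().max()  # small per-group reconstruction error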
@@ -705,13 +705,15 @@ def __init__(
             "weight",
             torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
         )
+        # MKG: torch.float
         self.register_buffer(
             "scales_and_zeros",
-            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
+            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.float)
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        input = input.to(torch.bfloat16)
+        # MKG torch.float
+        input = input.to(torch.float)
         if self.padding:
             import torch.nn.functional as F
             input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
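Finally, a small sketch reproducing the two buffer allocations above with concrete (illustrative) sizes, to make the dtype change visible: scales_and_zeros is now allocated in torch.float rather than torch.bfloat16, and forward likewise casts its input to torch.float.

import torch

in_features, out_features, groupsize, inner_k_tiles = 4096, 4096, 128, 8

weight = torch.empty(
    (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2),
    dtype=torch.int32,
)
scales_and_zeros = torch.empty(
    (in_features // groupsize, out_features, 2), dtype=torch.float
)
assert scales_and_zeros.dtype == torch.float32  # was torch.bfloat16 before this commit

x = torch.randn(1, in_features, dtype=torch.bfloat16)
x = x.to(torch.float)  # the cast forward() now performs on its input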
