
Commit

Float8 autoquant weight only
jainapurva committed Sep 20, 2024
1 parent b8ab4ee · commit 0ba6a2c
Showing 3 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion scripts/hf_eval.py
File mode changed 100755 → 100644
@@ -89,7 +89,7 @@ def all_linear(mod, name):
with torch.no_grad():
    result = evaluate(
        HFLM(
-            pretrained=model,
+            pretrained=model.to(device),
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_length=max_length),
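For context, the one-line change above moves the model onto the evaluation device before it is handed to the lm-eval harness. A minimal, self-contained sketch of the call pattern used in hf_eval.py, assuming lm_eval's evaluate/HFLM/get_task_dict entry points and using gpt2 and wikitext as placeholder checkpoint and task (neither is part of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from lm_eval.evaluator import evaluate
from lm_eval.models.huggingface import HFLM
from lm_eval.tasks import get_task_dict

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("gpt2")        # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")

with torch.no_grad():
    result = evaluate(
        HFLM(
            pretrained=model.to(device),  # the fix: move the model before wrapping it
            tokenizer=tokenizer,
            batch_size=8,
            max_length=2048),
        get_task_dict(["wikitext"]),      # placeholder task
    )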
Empty file modified scripts/prepare.sh
File mode changed 100755 → 100644
13 changes: 10 additions & 3 deletions torchao/quantization/autoquant.py
@@ -479,12 +479,19 @@ def from_float(cls, weight):

class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
    """
-    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight
+    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
    """
+    target_dtype: torch.dtype = torch.float8_e4m3fn
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)
+
    @classmethod
    def from_float(cls, weight):
        block_size = (1, weight.shape[1])
-        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=torch.float8_e4m3fn, layout_type=Float8LayoutType())
+        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())


# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -500,7 +507,7 @@ def from_float(cls, weight):
DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
    AQFloatLinearWeight,
    AQInt8DynamicallyQuantizedLinearWeight,
-    AQInt4G64WeightOnlyQuantizedLinearWeight,
+    AQInt4G64WeightOnlyQuantizedLinearWeight
]

def _change_linears_to_autoquantizable(model, **kwargs):
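For context, the class added above can be exercised through autoquant by passing a custom qtensor_class_list. A minimal usage sketch, assuming torchao.autoquant accepts a qtensor_class_list argument and that both classes are importable from torchao.quantization.autoquant as defined in this file; the toy model, shapes, and dtypes below are placeholders:

import torch
import torchao
from torchao.quantization.autoquant import (
    AQFloatLinearWeight,
    AQFloat8WeightOnlyQuantizedLinearWeight,
)

# Placeholder model and input; any model with nn.Linear layers would do.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)
example_input = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

# Restrict autoquant to the float baseline and the new float8 weight-only option,
# so only those two are benchmarked against each other per linear layer.
model = torchao.autoquant(
    torch.compile(model),
    qtensor_class_list=[AQFloatLinearWeight, AQFloat8WeightOnlyQuantizedLinearWeight],
)
model(example_input)  # first forward drives shape calibration and class selection (assumed autoquant behavior)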
