Float8 autoquant weight only

jainapurva committed Sep 17, 2024
1 parent 8e80b39 commit b8ab4ee
Showing 5 changed files with 10 additions and 6 deletions.
scripts/hf_eval.py (1 addition, 1 deletion; mode 100644 → 100755)
@@ -89,7 +89,7 @@ def all_linear(mod, name):
     with torch.no_grad():
         result = evaluate(
             HFLM(
-                pretrained=model.to(device),
+                pretrained=model,
                 tokenizer=tokenizer,
                 batch_size=batch_size,
                 max_length=max_length),
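
For orientation, a minimal sketch of the call pattern this hunk changes, assuming the lm-eval 0.4 style API that hf_eval.py uses (the model name, task, and batch settings below are illustrative, not taken from the script): the model is now handed to HFLM as-is instead of being moved with model.to(device) first, leaving device placement to the harness.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from lm_eval.models.huggingface import HFLM
from lm_eval.evaluator import evaluate
from lm_eval.tasks import get_task_dict

# Illustrative choices; hf_eval.py takes these from CLI flags.
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

with torch.no_grad():
    result = evaluate(
        HFLM(
            pretrained=model,      # no explicit .to(device) before wrapping
            tokenizer=tokenizer,
            batch_size=8,
            max_length=2048),
        get_task_dict(["wikitext"]),
    )
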
scripts/prepare.sh (empty file; mode 100644 → 100755)
torchao/dtypes/affine_quantized_tensor.py (1 addition, 1 deletion)
@@ -335,8 +335,8 @@ def from_hp_to_floatx(
         input_float: torch.Tensor,
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
-        scale_dtype: Optional[torch.dtype],
         layout_type: LayoutType,
+        scale_dtype: Optional[torch.dtype] = None,
     ):
 
         if target_dtype in FP8_TYPES:
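
Because scale_dtype now defaults to None and sits after layout_type, keyword-based callers keep working and can simply omit it. A hedged sketch of the resulting call shape; the tensor shape and import path are assumptions, while the keyword usage mirrors the from_hp_to_floatx call in autoquant.py below.

import torch
from torchao.dtypes.affine_quantized_tensor import (  # import path assumed for this revision
    AffineQuantizedTensor,
    Float8LayoutType,
)

weight = torch.randn(128, 256, dtype=torch.bfloat16)  # illustrative shape
fp8_weight = AffineQuantizedTensor.from_hp_to_floatx(
    weight,
    block_size=(1, weight.shape[1]),   # per-row quantization granularity
    target_dtype=torch.float8_e4m3fn,
    layout_type=Float8LayoutType(),
    # scale_dtype is omitted and falls back to None
)
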
torchao/kernel/intmm.py (4 additions, 1 deletion)
@@ -69,7 +69,10 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         input = (
             input.contiguous()
         )  # (it seems the transpose makes cublas check the above j constraint on i)
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        try:
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        except:
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
 else:
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
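
The practical effect is that safe_int_mm no longer raises if the out_dtype int-mm call fails; the (very broad) bare except falls back to a float32 matmul and casts the result back to int32. A small usage sketch with assumed int8 CPU inputs:

import torch
from torchao.kernel.intmm import safe_int_mm

a = torch.randint(-128, 127, (17, 64), dtype=torch.int8)  # shapes are illustrative
b = torch.randint(-128, 127, (64, 32), dtype=torch.int8)

out = safe_int_mm(a, b)  # an int32 result regardless of which internal path is taken
assert out.dtype == torch.int32
assert out.shape == (17, 32)
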
torchao/quantization/autoquant.py (4 additions, 3 deletions)
@@ -479,12 +479,12 @@ def from_float(cls, weight):

 class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
-    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
+    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight
     """
     @classmethod
     def from_float(cls, weight):
         block_size = (1, weight.shape[1])
-        return super(AQInt8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=torch.float8_e4m3fn, layout_type=Float8LayoutType())
+        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=torch.float8_e4m3fn, layout_type=Float8LayoutType())
 
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -494,12 +494,13 @@ def from_float(cls, weight):
     # AQInt8WeightOnlyQuantizedLinearWeight3,
     # TODO this gets picked in places where it makes perf worse, why?
     AQInt8DynamicallyQuantizedLinearWeight,
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 ]
 
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
     AQFloatLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
-    AQInt4G64WeightOnlyQuantizedLinearWeight
+    AQInt4G64WeightOnlyQuantizedLinearWeight,
 ]
 
 def _change_linears_to_autoquantizable(model, **kwargs):
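
With AQFloat8WeightOnlyQuantizedLinearWeight appended to DEFAULT_AUTOQUANT_CLASS_LIST, a default autoquant run now benchmarks a float8-e4m3fn weight-only candidate alongside the int8 options for each linear layer. A hedged end-to-end sketch; the model and shapes are illustrative, and a CUDA device with float8 support is assumed.

import torch
import torchao

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 1024),
).to(device="cuda", dtype=torch.bfloat16)

# Uses DEFAULT_AUTOQUANT_CLASS_LIST, which now includes the float8 weight-only candidate.
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))

# Running a sample input records shapes and triggers per-layer benchmarking/selection.
x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
model(x)
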
