From 0ba6a2c5b4c4f132e9107a4e63e794bce32b4461 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 10 Sep 2024 15:19:37 -0700
Subject: [PATCH] Float8 autoquant weight only

---
 scripts/hf_eval.py                |  2 +-
 scripts/prepare.sh                |  0
 torchao/quantization/autoquant.py | 13 ++++++++++---
 3 files changed, 11 insertions(+), 4 deletions(-)
 mode change 100755 => 100644 scripts/hf_eval.py
 mode change 100755 => 100644 scripts/prepare.sh

diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py
old mode 100755
new mode 100644
index b4171102d2..5f008ee439
--- a/scripts/hf_eval.py
+++ b/scripts/hf_eval.py
@@ -89,7 +89,7 @@ def all_linear(mod, name):
     with torch.no_grad():
         result = evaluate(
             HFLM(
-                pretrained=model,
+                pretrained=model.to(device),
                 tokenizer=tokenizer,
                 batch_size=batch_size,
                 max_length=max_length),
diff --git a/scripts/prepare.sh b/scripts/prepare.sh
old mode 100755
new mode 100644
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index cde651fa07..fa4ca36d85 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -479,12 +479,19 @@ def from_float(cls, weight):
 
 class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
-    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight
+    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
     """
+    target_dtype: torch.dtype = torch.float8_e4m3fn
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)
+
     @classmethod
     def from_float(cls, weight):
         block_size = (1, weight.shape[1])
-        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=torch.float8_e4m3fn, layout_type=Float8LayoutType())
+        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
+
 
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -500,7 +507,7 @@ def from_float(cls, weight):
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
     AQFloatLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
-    AQInt4G64WeightOnlyQuantizedLinearWeight,
+    AQInt4G64WeightOnlyQuantizedLinearWeight
 ]
 
 def _change_linears_to_autoquantizable(model, **kwargs):
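
Usage note (not part of the patch): a minimal sketch of exercising the new
AQFloat8WeightOnlyQuantizedLinearWeight via autoquant. The toy model, shapes,
device, and dtype below are illustrative assumptions; qtensor_class_list is
the existing autoquant keyword used to restrict the search space, here to the
float baseline plus the new float8 weight-only option.

    import torch
    import torchao
    from torchao.quantization.autoquant import (
        AQFloatLinearWeight,
        AQFloat8WeightOnlyQuantizedLinearWeight,
    )

    # Toy model; any model containing nn.Linear layers is handled the same way.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024))
    model = model.to(device="cuda", dtype=torch.bfloat16)

    # Restrict autoquant to the float baseline and the new float8 weight-only
    # class (assumed here to be benchmarked head-to-head against the baseline).
    model = torchao.autoquant(
        torch.compile(model),
        qtensor_class_list=[
            AQFloatLinearWeight,
            AQFloat8WeightOnlyQuantizedLinearWeight,
        ],
    )

    # Running the model records input shapes and triggers kernel selection.
    model(torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16))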