diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py
old mode 100644
new mode 100755
index 5f008ee439..b4171102d2
--- a/scripts/hf_eval.py
+++ b/scripts/hf_eval.py
@@ -89,7 +89,7 @@ def all_linear(mod, name):
     with torch.no_grad():
         result = evaluate(
             HFLM(
-                pretrained=model.to(device),
+                pretrained=model,
                 tokenizer=tokenizer,
                 batch_size=batch_size,
                 max_length=max_length),
diff --git a/scripts/prepare.sh b/scripts/prepare.sh
old mode 100644
new mode 100755
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 418e75d039..025f36ec39 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -335,8 +335,8 @@ def from_hp_to_floatx(
         input_float: torch.Tensor,
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
-        scale_dtype: Optional[torch.dtype],
         layout_type: LayoutType,
+        scale_dtype: Optional[torch.dtype] = None,
     ):
         if target_dtype in FP8_TYPES:
diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py
index 3005cb16a9..f13fb5bf55 100644
--- a/torchao/kernel/intmm.py
+++ b/torchao/kernel/intmm.py
@@ -69,7 +69,11 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
             input = (
                 input.contiguous()
             )  # (it seems the transpose makes cublas check the above j constraint on i)
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        try:
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        except Exception:
+            # The fused int_mm op is not available on every backend; fall back to a float32 matmul.
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
 else:
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 9e49a689ed..cde651fa07 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -479,12 +479,12 @@ def from_float(cls, weight):
 
 class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
-    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
+    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight
     """
     @classmethod
     def from_float(cls, weight):
         block_size = (1, weight.shape[1])
-        return super(AQInt8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=torch.float8_e4m3fn, layout_type=Float8LayoutType())
+        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=torch.float8_e4m3fn, layout_type=Float8LayoutType())
 
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -494,12 +494,13 @@ def from_float(cls, weight):
     # AQInt8WeightOnlyQuantizedLinearWeight3,
     # TODO this gets picked in places where it makes perf worse, why?
     AQInt8DynamicallyQuantizedLinearWeight,
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 ]
 
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
     AQFloatLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
-    AQInt4G64WeightOnlyQuantizedLinearWeight
+    AQInt4G64WeightOnlyQuantizedLinearWeight,
]
 
 def _change_linears_to_autoquantizable(model, **kwargs):
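
Not part of the patch itself: a minimal usage sketch of the two user-visible changes above, assuming the patched torchao is importable and, for the float8 part, a PyTorch build with float8 dtypes and a CUDA device. Tensor shapes, dtypes, and the device are illustrative assumptions.

import torch
from torchao.kernel.intmm import safe_int_mm
from torchao.quantization.autoquant import AQFloat8WeightOnlyQuantizedLinearWeight

# safe_int_mm now falls back to a float32 matmul if the fused int_mm op raises,
# so the call should still succeed and return an int32 tensor.
a = torch.randint(-128, 127, (32, 64), dtype=torch.int8)
b = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
assert safe_int_mm(a, b).dtype == torch.int32

# The new float8 weight-only wrapper quantizes a linear weight row-wise
# (block_size = (1, in_features)) to torch.float8_e4m3fn; in practice this
# needs recent PyTorch and suitable GPU support.
weight = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
fp8_weight = AQFloat8WeightOnlyQuantizedLinearWeight.from_float(weight)
print(type(fp8_weight).__name__, fp8_weight.shape)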