From bbc8dcd8c0cc5973936618564d8a59514fb87b12 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Wed, 15 Jan 2025 11:14:24 -0800 Subject: [PATCH] Fix bug & apply review recommendations --- test/integration/test_integration.py | 3 +-- torchao/quantization/quant_primitives.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 7e9787f07f..4d39c4d9ae 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -10,6 +10,7 @@ import logging import os import unittest +from functools import partial import torch import torch.nn as nn @@ -1004,8 +1005,6 @@ def _test_lin_weight_subclass_api_impl( def test_int8_dynamic_quant_subclass_api( self, device, dtype, act_mapping, weight_zero_point_domain ): - from functools import partial - if ( not TORCH_VERSION_AT_LEAST_2_5 and dtype in (torch.float16, torch.bfloat16) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 61b508bdc0..949afc968f 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -372,7 +372,7 @@ def _quantize_affine_no_dtype_cast( zero_point: Optional[torch.Tensor], quant_min: Union[int, float], quant_max: Union[int, float], - quant_dtype: Optional[torch.dtype], + quant_dtype: torch.dtype, zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, ) -> torch.Tensor: """