Skip tests broken by change of _convert_weight_to_int4pack #504

Merged · 1 commit · Jul 16, 2024
8 changes: 1 addition & 7 deletions test/dtypes/test_uint4.py
@@ -4,19 +4,14 @@
PerChannelSymmetricWeightUInt4Tensor,
)
import unittest
from unittest import TestCase, main
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch._export import dynamic_dim
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
)
from torchao.quantization.utils import (
compute_error,
)
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
)
@@ -30,7 +25,6 @@
QuantizationAnnotation,
)
import copy
from packaging import version


def _apply_weight_only_uint4_quant(model):
@@ -229,4 +223,4 @@ def forward(self, x):
)

if __name__ == "__main__":
main()
unittest.main()
8 changes: 6 additions & 2 deletions test/integration/test_integration.py
@@ -81,6 +81,7 @@
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
TORCH_VERSION_AFTER_2_5,
unwrap_tensor_subclass,
is_fbcode,
benchmark_model
@@ -734,6 +735,7 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -744,6 +746,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -1020,7 +1023,8 @@ def test_save_load_int8woqtensors(self, device, dtype):
self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch 2.3+.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
@torch.no_grad()
def test_save_load_int4woqtensors(self, device, dtype):
if dtype != torch.bfloat16:
@@ -1500,7 +1504,7 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype):


class TestBenchmarkModel(unittest.TestCase):

class ToyLinearModel(torch.nn.Module):
def __init__(self, m=64, n=32, k=64):
super().__init__()
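The new skips all gate on `TORCH_VERSION_AFTER_2_5`, which this PR adds to the `torchao.utils` imports. As a rough sketch (not torchao's actual implementation), such a flag can be derived from `torch.__version__` with `packaging`, using a `.dev` lower bound so nightlies count as "after 2.5":

```python
import torch
from packaging import version


def _torch_version_at_least(min_version: str) -> bool:
    # Compare the installed torch version against a minimum version string.
    # packaging handles local suffixes such as "+cu121" and ".dev" tags.
    return version.parse(torch.__version__) >= version.parse(min_version)


# Use an early .dev bound so 2.5 nightlies are treated as "after 2.5",
# which is what the skips above rely on.
TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev0")
```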
2 changes: 2 additions & 0 deletions test/quantization/test_quant_api.py
@@ -44,6 +44,7 @@
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
TORCH_VERSION_AFTER_2_5,
)
from pathlib import Path
from torchao._models.llama.tokenizer import get_tokenizer
@@ -522,6 +523,7 @@ def test_quantized_tensor_subclass_8da4w(self):
self.assertTrue(torch.equal(res, ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int4(self):
# use 1024 so that we don't need padding
2 changes: 2 additions & 0 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -461,6 +461,8 @@ def __tensor_unflatten__(

@classmethod
def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
# assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
# packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
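The commented-out lines record the breakage this PR works around: on recent nightlies `torch.ops.aten._convert_weight_to_int4pack` expects a `uint8` input, while earlier releases take the unpacked values as `int32`. A minimal sketch of a version guard around the call (an assumption about how it could be handled later; this PR keeps the `int32` path and skips the affected tests instead):

```python
import torch

from torchao.utils import TORCH_VERSION_AFTER_2_5


def _pack_int4_weight(int_data: torch.Tensor, inner_k_tiles: int = 8) -> torch.Tensor:
    # Pre-2.5: the op takes the unpacked int4 values as an int32 tensor.
    # 2.5+ nightlies changed the op to expect uint8 input, which this PR
    # does not yet support -- hence the skipped tests above.
    if TORCH_VERSION_AFTER_2_5:
        raise NotImplementedError(
            "_convert_weight_to_int4pack expects uint8 input on torch 2.5+; "
            "this path is not handled yet (the covering tests are skipped)"
        )
    return torch.ops.aten._convert_weight_to_int4pack(
        int_data.to(torch.int32), inner_k_tiles
    )
```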
3 changes: 2 additions & 1 deletion torchao/quantization/utils.py
@@ -348,7 +348,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
quant_min = 0
quant_max = 2 ** n_bit - 1

return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
return int_data

def groupwise_affine_dequantize_tensor_from_qparams(
w_int4x8,
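For the int4 path the skipped tests exercise, `n_bit` is 4, so the range computed above is [0, 15] and the zero point stays in the float domain (`ZeroPointDomain.FLOAT`). A quick sanity check of that range arithmetic:

```python
n_bit = 4
quant_min = 0
quant_max = 2 ** n_bit - 1  # 15: sixteen unsigned levels for uint4 storage
assert (quant_min, quant_max) == (0, 15)
```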