
Commit 26bb079

Merge branch 'pytorch:main' into feat/blockwise_fp8_quant_triton_gemm_ker
2 parents 31ed7c1 + f343336 commit 26bb079

33 files changed: +424 −431 lines

.pre-commit-config.yaml

+1 −1
@@ -2,7 +2,7 @@
 # See https://pre-commit.com/hooks.html for more hooks
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v5.0.0
     hooks:
       - id: trailing-whitespace
       - id: end-of-file-fixer

benchmarks/float8/utils.py

+3
@@ -83,6 +83,9 @@ def profiler_output_to_filtered_time_by_kernel_name(
             continue
         elif e.key == "Activity Buffer Request":
             continue
+        elif e.key == "Unrecognized":
+            # TODO I think these are nvjet related
+            continue

         kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total
     return kernel_name_to_gpu_time_us
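For context on what this helper does, here is a minimal sketch of the filtering, assuming the events come from torch.profiler's key_averages() (whose entries expose .key and .self_device_time_total in recent PyTorch); the function name and surrounding loop here are assumptions based on the hunk, not the full file.

def filtered_gpu_time_by_kernel(events, skip_keys=("Activity Buffer Request", "Unrecognized")):
    # Map kernel name -> self GPU time in microseconds, dropping bookkeeping events.
    kernel_name_to_gpu_time_us = {}
    for e in events:
        if e.key in skip_keys:
            # "Unrecognized" entries are skipped too (the TODO above suspects they are nvjet related).
            continue
        kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total
    return kernel_name_to_gpu_time_us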

test/dtypes/test_affine_quantized.py

+18 −3
@@ -25,6 +25,7 @@
     to_affine_quantized_intx_static,
 )
 from torchao.quantization import (
+    GemliteUIntXWeightOnlyConfig,
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
     float8_weight_only,
@@ -36,7 +37,7 @@
     quantize_,
 )
 from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
-from torchao.testing.utils import skip_if_no_cuda, skip_if_rocm
+from torchao.testing.utils import skip_if_no_cuda, skip_if_no_gemlite, skip_if_rocm
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     check_cpu_version,
@@ -176,7 +177,7 @@ def _apply(module, config_or_subclass_inserter):

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_register_new_dispatch(self):
-        from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
+        from torchao.dtypes import AffineQuantizedTensor
         from torchao.dtypes.affine_quantized_tensor_ops import (
             deregister_aqt_quantized_linear_dispatch,
             register_aqt_quantized_linear_dispatch,
@@ -344,7 +345,7 @@ def test_alias(self, device, dtype):
     @common_utils.parametrize("device", ["cuda"])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     @skip_if_no_cuda()
-    def test_slice(self, device, dtype):
+    def test_slice_int4wo(self, device, dtype):
         # in_feature not divisible by 1024
         # out_feature not divisible by 8
         # to test slice + padding for int4 weight only quantization
@@ -354,6 +355,20 @@ def test_slice(self, device, dtype):
         _ = dummy.weight.narrow(0, 0, 64)
         _ = dummy.weight.narrow(1, 0, 128)

+    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("dtype", [torch.float16, torch.bfloat16])
+    @skip_if_no_cuda()
+    @skip_if_no_gemlite()
+    def test_slice_gemlite(self, device, dtype):
+        # in_feature not divisible by 1024
+        # out_feature not divisible by 8
+        # to test slice + padding for int4 weight only quantization
+        dummy = nn.Linear(256, 512, dtype=dtype, device=device)
+        quantize_(dummy, GemliteUIntXWeightOnlyConfig())
+        # make sure these run without error
+        _ = dummy.weight.narrow(0, 0, 64)
+        _ = dummy.weight.narrow(1, 0, 128)
+
     @common_utils.parametrize("device", ["cuda"])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     def test_matmul(self, device, dtype):
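The new test_slice_gemlite exercises slicing on a GemLite-quantized weight. As a standalone sketch of the same flow, assuming a CUDA device and the optional gemlite dependency are installed:

import torch
import torch.nn as nn
from torchao.quantization import GemliteUIntXWeightOnlyConfig, quantize_

linear = nn.Linear(256, 512, dtype=torch.float16, device="cuda")
quantize_(linear, GemliteUIntXWeightOnlyConfig())
# slicing the quantized weight should run without error
_ = linear.weight.narrow(0, 0, 64)   # along out_features
_ = linear.weight.narrow(1, 0, 128)  # along in_features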

test/quantization/test_qat.py

+3 −2
@@ -1474,7 +1474,6 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    @unittest.skip("Currently failing on sqnr")
     def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         """
         Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
@@ -1493,7 +1492,9 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         torch.manual_seed(seed)
         x = m.example_inputs()

-        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        quantizer = Int8DynActInt4WeightQATQuantizer(
+            groupsize=group_size, precision=dtype, scales_precision=dtype
+        )
         prepared = quantizer.prepare(m)
         prepared_out = prepared(*x)
         converted = quantizer.convert(prepared)
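The unskipped test now constructs the quantizer with explicit precisions. Roughly, the prepare/convert flow it compares looks like the sketch below; the toy model and the torchao.quantization.qat import path are assumptions based on recent torchao, adjust to your setup.

import torch
import torch.nn as nn
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

model = nn.Sequential(nn.Linear(256, 256)).to(torch.bfloat16)
example_inputs = (torch.randn(1, 256, dtype=torch.bfloat16),)

quantizer = Int8DynActInt4WeightQATQuantizer(
    groupsize=32, precision=torch.bfloat16, scales_precision=torch.bfloat16
)
prepared = quantizer.prepare(model)        # model with fake-quantized linears
prepared_out = prepared(*example_inputs)
converted = quantizer.convert(prepared)    # model with real quantized linears
converted_out = converted(*example_inputs)
# the test compares prepared_out and converted_out (SQNR check)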

test/quantization/test_quant_api.py

+12
@@ -1005,6 +1005,18 @@ def test_ao_per_module_config_embedding_linear(self):
         assert isinstance(model.emb.weight._layout, QDQLayout)
         assert isinstance(model.linear.weight, LinearActivationQuantizedTensor)

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_ao_per_module_config_skip(self):
+        config1 = Int4WeightOnlyConfig(group_size=32)
+        config = AOPerModuleConfig({"_default": config1, "linear2": None})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
+        assert not isinstance(model.linear2.weight, AffineQuantizedTensor)
+

 class TestMultiTensorFlow(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
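The new test covers skipping a module with AOPerModuleConfig by mapping its name to None while "_default" applies everywhere else. A minimal sketch, under the assumption that AOPerModuleConfig is importable from torchao.quantization alongside the other configs; the two-linear model is a stand-in for the test's ToyLinearModel.

import torch
import torch.nn as nn
from torchao.quantization import AOPerModuleConfig, Int4WeightOnlyConfig, quantize_

class TwoLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(128, 256)
        self.linear2 = nn.Linear(256, 128)

    def forward(self, x):
        return self.linear2(self.linear1(x))

model = TwoLinear().cuda().to(torch.bfloat16)

# "_default" applies to every module unless overridden; "linear2": None skips quantization for it
config = AOPerModuleConfig({"_default": Int4WeightOnlyConfig(group_size=32), "linear2": None})
quantize_(model, config)
print(type(model.linear1.weight))  # AffineQuantizedTensor (int4 weight-only)
print(type(model.linear2.weight))  # plain Parameter, left unquantized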
