-
Notifications
You must be signed in to change notification settings - Fork 215
/
Copy pathtest_quant_api.py
828 lines (708 loc) · 31.3 KB
/
test_quant_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# mypy: ignore-errors
# This test takes a long time to run
import copy
import gc
import tempfile
import unittest
from pathlib import Path
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase
from torchao import quantize_
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
from torchao._models.llama.tokenizer import get_tokenizer
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.quant_api import (
Quantizer,
TwoStepQuantizer,
_replace_with_custom_fn_if_matches_filter,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.subclass import (
Int4WeightOnlyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
unwrap_tensor_subclass,
)
def dynamic_quant(model, example_inputs):
    """Run ``model`` through the PT2E XNNPACK dynamic-quantization flow.

    Args:
        model: eager ``torch.nn.Module`` to quantize.
        example_inputs: tuple of positional inputs used to trace ``model``.

    Returns:
        The converted module produced by ``convert_pt2e``.
    """
    graph_module = torch.export.export(model, example_inputs, strict=True).module()
    dynamic_config = get_symmetric_quantization_config(is_dynamic=True)
    xnnpack_quantizer = XNNPACKQuantizer().set_global(dynamic_config)
    prepared = prepare_pt2e(graph_module, xnnpack_quantizer)
    return convert_pt2e(prepared)
def capture_and_prepare(model, example_inputs):
    """Export ``model`` and insert PT2E observers for XNNPACK dynamic quantization.

    The prepared module is run once on ``example_inputs`` so its observers have
    seen data before ``convert_pt2e`` is applied.

    Args:
        model: eager ``torch.nn.Module`` to capture.
        example_inputs: tuple of positional inputs, used both for tracing and
            for the single calibration forward pass.

    Returns:
        The prepared (observed) module returned by ``prepare_pt2e``.
    """
    # torch.export.export returns an ExportedProgram. prepare_pt2e operates on
    # the GraphModule inside it, and an ExportedProgram cannot be called like a
    # module. Unwrap it with .module(), exactly as dynamic_quant does.
    m = torch.export.export(model, example_inputs, strict=True).module()
    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_dynamic=True)
    )
    m = prepare_pt2e(m, quantizer)
    # TODO: we can run the weight observer in convert_pt2e so that user don't need to run this
    m(*example_inputs)
    return m
class XNNPackDynamicQuantizer(TwoStepQuantizer):
    """Two-step quantizer that replaces each ``nn.Linear`` with a PT2E XNNPACK module.

    ``prepare`` exports every linear layer and inserts dynamic-quant observers.
    ``convert`` turns every prepared ``torch.fx.GraphModule`` into its quantized form.
    """

    def prepare(self, model: torch.nn.Module) -> torch.nn.Module:
        """Swap each ``nn.Linear`` in ``model`` (in place) for a prepared module.

        Returns:
            The same ``model`` object, for chaining.
        """
        _replace_with_custom_fn_if_matches_filter(
            model,
            # torch.export.export requires positional example inputs as a tuple.
            # The trailing comma is essential: ``(x)`` is just ``x``, not ``(x,)``.
            lambda linear_mod: capture_and_prepare(
                linear_mod, (torch.randn(1, linear_mod.in_features),)
            ),
            lambda mod, fqn: isinstance(mod, torch.nn.Linear),
        )
        return model

    def convert(self, model: torch.nn.Module) -> torch.nn.Module:
        """Run ``convert_pt2e`` on every ``GraphModule`` left by ``prepare``.

        Returns:
            The same ``model`` object, for chaining.
        """
        _replace_with_custom_fn_if_matches_filter(
            model,
            lambda linear_mod: convert_pt2e(linear_mod),
            lambda mod, fqn: isinstance(mod, torch.fx.GraphModule),
        )
        return model
class TorchCompileDynamicQuantizer(Quantizer):
    """``Quantizer`` that applies torchao int8 dynamic-activation / int8-weight quant."""

    def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
        """Quantize ``model`` in place with ``quantize_`` and hand back the same object."""
        config = int8_dynamic_activation_int8_weight()
        quantize_(model, config)
        return model
class ToyLinearModel(torch.nn.Module):
    """Two bias-free linear layers applied in sequence: ``m -> n -> k`` features.

    Args:
        m: input feature count of ``linear1``.
        n: output features of ``linear1``, which are the input features of ``linear2``.
        k: output feature count of ``linear2``.
    """

    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        # Explicit cast so both layers are float32 regardless of the global default dtype.
        self.linear1 = torch.nn.Linear(in_features=m, out_features=n, bias=False).to(
            torch.float
        )
        self.linear2 = torch.nn.Linear(in_features=n, out_features=k, bias=False).to(
            torch.float
        )

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        """Return a 1-tuple holding a random ``(batch_size, m)`` input tensor."""
        in_features = self.linear1.in_features
        sample = torch.randn(batch_size, in_features, dtype=dtype, device=device)
        return (sample,)

    def forward(self, x):
        return self.linear2(self.linear1(x))
def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
    """Reference (deprecated) int8 dynamic-quant flow, kept to compare numerics and performance.

    Every module in ``model`` selected by ``filter_fn`` has its weight replaced,
    in place, with ``Int8DynamicallyQuantizedLinearWeight``. Parametrization is
    disabled for this replacement.

    Args:
        model: module tree to modify in place.
        filter_fn: selector called with whatever ``_replace_with_custom_fn_if_matches_filter``
            passes it. Defaults to "is linear and passes ``_in_features_greater_than_16``".
        **kwargs: forwarded unchanged to ``_get_subclass_inserter``.
    """
    from torchao.quantization.quant_api import (
        _get_subclass_inserter,
        _in_features_greater_than_16,
        _is_linear,
    )
    from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight

    def _default_filter(*args):
        return _is_linear(*args) and _in_features_greater_than_16(*args)

    selector = filter_fn if filter_fn is not None else _default_filter
    inserter = _get_subclass_inserter(
        Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs
    )
    _replace_with_custom_fn_if_matches_filter(model, inserter, selector)
def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
"""
The deprecated implementation for weight only quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import _get_subclass_inserter, _is_linear
filter_fn = kwargs.pop("filter_fn", _is_linear)
_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(
deprecated_tenosr_subclass, enable_parametrization=True, **kwargs
),
filter_fn,
)
return _ref_change_linear_weights_to_woqtensors
# Reference weight-only quantization functions built from the deprecated tensor
# subclasses. The TestQuantFlow int8/int4 weight-only tests compare against them.
_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(
    Int8WeightOnlyQuantizedLinearWeight
)
_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(
    Int4WeightOnlyQuantizedLinearWeight
)
class TestQuantFlow(TestCase):
    """End-to-end tests for torchao quantization entry points.

    Covers ``quantize_`` with the int8/int4 configs, the ``Quantizer`` /
    ``TwoStepQuantizer`` helpers defined in this file, and state_dict save/load.
    Also covers device moves and export/compile paths. The Llama-2/3 checkpoint
    accuracy checks are currently skipped.
    """
    def test_dynamic_quant_gpu_singleline(self):
        """Smoke test: int8 dynamic-activation / int8-weight ``quantize_`` followed by a forward pass.

        NOTE(review): despite "gpu" in the name, the model and inputs stay on CPU here.
        """
        m = ToyLinearModel().eval()
        example_inputs = m.example_inputs()
        quantize_(m, int8_dynamic_activation_int8_weight())
        m(*example_inputs)
        # NOTE(review): the eager-vs-compiled comparison below is disabled. Re-enabling
        # it also requires binding the forward result above to `quantized`.
        # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
        # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
        # m = torch.compile(m, mode="max-autotune")
        # print(example_inputs[0].dtype)
        # compiled = m(*example_inputs)
        # torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
    @unittest.skip("skipping for now due to torch.compile error")
    def test_dynamic_quant_gpu_unified_api_unified_impl(self):
        """Eager vs ``torch.compile`` outputs must match exactly for the PT2E XNNPACK two-step quantizer."""
        quantizer = XNNPackDynamicQuantizer()
        m = ToyLinearModel().eval()
        example_inputs = m.example_inputs()
        # Two-step flow: prepare inserts observers, convert produces the quantized modules.
        m = quantizer.prepare(m)
        m = quantizer.convert(m)
        quantized = m(*example_inputs)
        # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
        # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
        m = torch.compile(m, mode="max-autotune")
        # print(example_inputs[0].dtype)
        compiled = m(*example_inputs)
        # Bitwise equality is required (atol=0, rtol=0).
        torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
    @unittest.skip(
        "FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!"
    )
    def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
        """Eager vs max-autotune ``torch.compile`` outputs must match exactly after TorchCompileDynamicQuantizer."""
        quantizer = TorchCompileDynamicQuantizer()
        m = ToyLinearModel().eval()
        example_inputs = m.example_inputs()
        m = quantizer.quantize(m)
        quantized = m(*example_inputs)
        m = torch.compile(m, mode="max-autotune")
        compiled = m(*example_inputs)
        torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "only works for torch 2.4+")
    def test_int8_wo_quant_save_load(self):
        """Round-trip an unwrapped int8 weight-only CPU model through a state_dict file.

        The state_dict is loaded into a fresh model, which is moved to CUDA. Its
        outputs must match the original CPU model's outputs.
        """
        m = ToyLinearModel().eval().cpu()
        # Same quantize + unwrap pipeline is applied to both models so their
        # state_dict layouts match for load_state_dict.
        def api(model):
            quantize_(model, int8_weight_only())
            unwrap_tensor_subclass(model)
        api(m)
        example_inputs = m.example_inputs()
        ref = m(*example_inputs)
        with tempfile.NamedTemporaryFile() as f:
            torch.save(m.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
        m2 = ToyLinearModel().eval().cpu()
        api(m2)
        m2.load_state_dict(state_dict)
        m2 = m2.to(device="cuda")
        # `map` is lazy; it is consumed exactly once by the * unpacking below.
        example_inputs = map(lambda x: x.cuda(), example_inputs)
        res = m2(*example_inputs)
        torch.testing.assert_close(ref, res.cpu())
    @unittest.skipIf(
        not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower"
    )
    def test_8da4w_quantizer(self):
        """Int8DynActInt4WeightQuantizer swaps both linears for Int8DynActInt4WeightLinear, and the result runs."""
        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
        quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
        m = ToyLinearModel().eval()
        example_inputs = m.example_inputs()
        m = quantizer.quantize(m)
        assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
        assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
        m(*example_inputs)
    # TODO: save model weights as artifacts and re-enable in CI
    # For now, to run this test, you will need to download the weights from HF
    # and run this script to convert them:
    # https://github.com/pytorch-labs/gpt-fast/blob/6253c6bb054e658d67566150f87329b87815ae63/scripts/convert_hf_checkpoint.py
    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    def test_8da4w_gptq_quantizer(self):
        """GPTQ 8da4w on Llama-2-7b-chat (bf16, CPU); wikitext word perplexity must stay below 7.88."""
        from torchao._models._eval import InputRecorder, TransformerEvalWrapper
        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer
        # should be similar to TorchCompileDynamicQuantizer
        precision = torch.bfloat16
        device = "cpu"
        # Relative path to a locally converted gpt-fast checkpoint (see comment above the test).
        checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
        model = Transformer.from_name(checkpoint_path.parent.name)
        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        model.load_state_dict(checkpoint, assign=True)
        model = model.to(dtype=precision, device=device)
        model.eval()
        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = get_tokenizer(  # pyre-ignore[28]
            tokenizer_path,
            "Llama-2-7b-chat-hf",
        )
        # GPTQ hyperparameters and calibration settings.
        blocksize = 128
        percdamp = 0.01
        groupsize = 128
        calibration_tasks = ["wikitext"]
        calibration_limit = 1
        calibration_seq_length = 100
        input_prep_func = prepare_inputs_for_model
        pad_calibration_inputs = False
        # Record calibration inputs from the wikitext task for GPTQ.
        inputs = (
            InputRecorder(
                tokenizer,
                calibration_seq_length,
                input_prep_func,
                pad_calibration_inputs,
                model.config.vocab_size,
            )
            .record_inputs(
                calibration_tasks,
                calibration_limit,
            )
            .get_inputs()
        )
        quantizer = Int8DynActInt4WeightGPTQQuantizer(
            blocksize,
            percdamp,
            groupsize,
            precision=precision,
        )
        model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
        model = quantizer.quantize(model, inputs)
        result = TransformerEvalWrapper(
            model,
            tokenizer,
            model.config.block_size,
            prepare_inputs_for_model,
            device,
        ).run_eval(
            ["wikitext"],
            1,
        )
        assert (
            result["results"]["wikitext"]["word_perplexity,none"] < 7.88
        ), f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}"
    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    @unittest.skipIf(
        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch verion is 2.4 or lower"
    )
    def test_8da4w_quantizer_eval(self):
        """Non-GPTQ 8da4w (groupsize 128) on Llama-2-7b-chat; wikitext word perplexity must stay below 8.24."""
        from torchao._models._eval import TransformerEvalWrapper
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
        precision = torch.bfloat16
        device = "cpu"
        checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
        model = Transformer.from_name(checkpoint_path.parent.name)
        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        model.load_state_dict(checkpoint, assign=True)
        model = model.to(dtype=precision, device=device)
        model.eval()
        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = get_tokenizer(  # pyre-ignore[28]
            tokenizer_path,
            "Llama-2-7b-chat-hf",
        )
        quantizer = Int8DynActInt4WeightQuantizer(groupsize=128, precision=precision)
        q_model = quantizer.quantize(model)
        result = TransformerEvalWrapper(
            q_model,
            tokenizer,
            q_model.config.block_size,
            prepare_inputs_for_model,
            device,
        ).run_eval(
            ["wikitext"],
            1,
        )
        assert (
            result["results"]["wikitext"]["word_perplexity,none"] < 8.24
        ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    def test_gptq_quantizer_int4_weight_only(self):
        """GPTQ int4 weight-only on Llama-2-7b-chat; wikitext word perplexity must stay below 7.77.

        Uses the MultiTensor variant. Calibration runs with the model on CPU; the
        quantized model is then moved to CUDA for evaluation.
        """
        from torchao._models._eval import (
            MultiTensorInputRecorder,
            TransformerEvalWrapper,
        )
        from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer
        precision = torch.bfloat16
        device = "cuda"
        checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
        model = Transformer.from_name(checkpoint_path.parent.name)
        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        model.load_state_dict(checkpoint, assign=True)
        # Model stays on CPU during GPTQ calibration; `device` is only used for eval below.
        model = model.to(dtype=precision, device="cpu")
        model.eval()
        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = get_tokenizer(  # pyre-ignore[28]
            tokenizer_path,
            "Llama-2-7b-chat-hf",
        )
        blocksize = 128
        percdamp = 0.01
        groupsize = 64
        calibration_tasks = ["wikitext"]
        calibration_limit = 5
        calibration_seq_length = 100
        input_prep_func = prepare_inputs_for_model
        pad_calibration_inputs = False
        inputs = (
            MultiTensorInputRecorder(
                tokenizer,
                calibration_seq_length,
                input_prep_func,
                pad_calibration_inputs,
                model.config.vocab_size,
                device="cpu",
            )
            .record_inputs(
                calibration_tasks,
                calibration_limit,
            )
            .get_inputs()
        )
        quantizer = Int4WeightOnlyGPTQQuantizer(
            blocksize,
            percdamp,
            groupsize,
        )
        model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
        model = quantizer.quantize(model, inputs).cuda()
        result = TransformerEvalWrapper(
            model.cuda(),
            tokenizer,
            model.config.block_size,
            prepare_inputs_for_model,
            device,
        ).run_eval(
            ["wikitext"],
            # NOTE(review): presumably None means "no limit" for run_eval; confirm in _eval.
            None,
        )
        assert (
            result["results"]["wikitext"]["word_perplexity,none"] < 7.77
        ), f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    def test_quantizer_int4_weight_only(self):
        """Non-GPTQ Int4WeightOnlyQuantizer (groupsize 64) on CUDA; wikitext word perplexity must stay below 8.24."""
        from torchao._models._eval import TransformerEvalWrapper
        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
        precision = torch.bfloat16
        device = "cuda"
        checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
        model = Transformer.from_name(checkpoint_path.parent.name)
        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        model.load_state_dict(checkpoint, assign=True)
        model = model.to(dtype=precision, device=device)
        model.eval()
        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = get_tokenizer(  # pyre-ignore[28]
            tokenizer_path,
            "Llama-2-7b-chat-hf",
        )
        groupsize = 64
        quantizer = Int4WeightOnlyQuantizer(
            groupsize,
        )
        model = quantizer.quantize(model).cuda()
        result = TransformerEvalWrapper(
            model,
            tokenizer,
            model.config.block_size,
            prepare_inputs_for_model,
            device,
        ).run_eval(
            ["wikitext"],
            1,
        )
        assert (
            result["results"]["wikitext"]["word_perplexity,none"] < 8.24
        ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    def test_eval_wrapper(self):
        """Unquantized bf16 Llama-2-7b-chat baseline; wikitext word perplexity must stay below 7.77."""
        from torchao._models._eval import TransformerEvalWrapper
        precision = torch.bfloat16
        device = "cuda"
        checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
        model = Transformer.from_name(checkpoint_path.parent.name)
        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        model.load_state_dict(checkpoint, assign=True)
        model = model.to(dtype=precision, device=device)
        model.eval()
        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = get_tokenizer(  # pyre-ignore[28]
            tokenizer_path,
            "Llama-2-7b-chat-hf",
        )
        result = TransformerEvalWrapper(
            model,
            tokenizer,
            model.config.block_size,
            prepare_inputs_for_model,
            device,
        ).run_eval(
            ["wikitext"],
            1,
        )
        assert (
            result["results"]["wikitext"]["word_perplexity,none"] < 7.77
        ), f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
    # EVAL IS CURRENTLY BROKEN FOR LLAMA 3, VERY LOW ACCURACY
    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    def test_eval_wrapper_llama3(self):
        """Unquantized bf16 Meta-Llama-3-8B baseline; wikitext word perplexity must stay below 8.24."""
        from torchao._models._eval import TransformerEvalWrapper
        precision = torch.bfloat16
        device = "cuda"
        # NOTE(review): ".../gpt-fast" looks like a placeholder; point it at a local
        # Llama-3 checkpoint before running.
        checkpoint_path = Path(
            ".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth"
        )
        model = Transformer.from_name(checkpoint_path.parent.name)
        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        model.load_state_dict(checkpoint, assign=True)
        model = model.to(dtype=precision, device=device)
        model.eval()
        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = get_tokenizer(  # pyre-ignore[28]
            tokenizer_path,
            "Meta-Llama-3-8B",
        )
        result = TransformerEvalWrapper(
            model,
            tokenizer,
            model.config.block_size,
            prepare_inputs_for_model,
            device,
        ).run_eval(
            ["wikitext"],
            1,
        )
        assert (
            result["results"]["wikitext"]["word_perplexity,none"] < 8.24
        ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
    # TODO: move to a separate test file
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @common_utils.parametrize(
        "mapping_type", [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR]
    )
    def test_quantized_tensor_subclass_8da4w(self, mapping_type):
        """Tensor-subclass 8da4w via ``quantize_`` must bitwise-match the module-swap Int8DynActInt4WeightQuantizer.

        The comparison is repeated for each ``mapping_type``.
        """
        group_size = 32
        m = ToyLinearModel().eval()
        # Deep copy before quantizing so the reference starts from identical float weights.
        m_copy = copy.deepcopy(m)
        example_inputs = m.example_inputs()
        quantize_(
            m,
            int8_dynamic_activation_int4_weight(
                group_size=group_size, mapping_type=mapping_type
            ),
        )
        # New API: weights become LinearActivationQuantizedTensor wrapping an AffineQuantizedTensor.
        assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
        assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
        assert isinstance(
            m.linear1.weight.original_weight_tensor, AffineQuantizedTensor
        )
        assert isinstance(
            m.linear2.weight.original_weight_tensor, AffineQuantizedTensor
        )
        # reference
        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
        quantizer = Int8DynActInt4WeightQuantizer(
            groupsize=group_size, mapping_type=mapping_type
        )
        m_copy = quantizer.quantize(m_copy)
        assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
        assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)
        res = m(*example_inputs)
        ref = m_copy(*example_inputs)
        self.assertTrue(torch.equal(res, ref))
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_quantized_tensor_subclass_int4(self):
        """``int4_weight_only`` must bitwise-match the deprecated Int4WeightOnlyQuantizedLinearWeight path (CUDA, bf16)."""
        # use 1024 so that we don't need padding
        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
        m_copy = copy.deepcopy(m)
        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
        group_size = 32
        quantize_(m, int4_weight_only(group_size=group_size))
        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
        # reference
        _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size)
        res = m(*example_inputs)
        ref = m_copy(*example_inputs)
        self.assertTrue(torch.equal(res, ref))
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_quantized_tensor_subclass_int8_wo(self):
        """``int8_weight_only`` must bitwise-match the deprecated Int8WeightOnlyQuantizedLinearWeight path (bf16).

        NOTE(review): gated on CUDA, but the model and inputs stay on CPU.
        """
        m = ToyLinearModel().eval().to(torch.bfloat16)
        m_copy = copy.deepcopy(m)
        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
        quantize_(m, int8_weight_only())
        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
        # reference
        _ref_change_linear_weights_to_int8_woqtensors(m_copy)
        res = m(*example_inputs)
        ref = m_copy(*example_inputs)
        self.assertTrue(torch.equal(res, ref))
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.5 and below")
    def test_quantized_tensor_subclass_int8_dyn_quant(self):
        """int8 dynamic quant must bitwise-match the deprecated dqtensor path, including after export.

        The exported (unwrapped) module must also match the reference, and
        ``aot_compile`` on the unwrapped module must succeed.
        """
        # use multiples of 1024 so that we don't need padding
        m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda")
        m_copy = copy.deepcopy(m)
        # setting batch_size to 20 to be compatible with the kernel
        example_inputs = m.example_inputs(
            batch_size=20, dtype=torch.bfloat16, device="cuda"
        )
        quantize_(m, int8_dynamic_activation_int8_weight())
        assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
        assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
        assert isinstance(
            m.linear1.weight.original_weight_tensor, AffineQuantizedTensor
        )
        assert isinstance(
            m.linear2.weight.original_weight_tensor, AffineQuantizedTensor
        )
        # reference
        _ref_change_linear_weights_to_int8_dqtensors(m_copy)
        res = m(*example_inputs)
        ref = m_copy(*example_inputs)
        self.assertTrue(torch.equal(res, ref))
        # workaround for export path
        from torchao.utils import unwrap_tensor_subclass
        m_unwrapped = unwrap_tensor_subclass(m)
        # `m` is rebound to the exported module; aot_compile below uses m_unwrapped.
        m = torch.export.export(m_unwrapped, example_inputs, strict=True).module()
        exported_model_res = m(*example_inputs)
        self.assertTrue(torch.equal(exported_model_res, ref))
        # make sure it compiles
        torch._export.aot_compile(m_unwrapped, example_inputs)
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_quantized_tensor_subclass_save_load(self):
        """A state_dict holding int8 weight-only tensor subclasses round-trips through torch.save/torch.load.

        It is loaded with ``assign=True`` into an unquantized copy of the model.
        """
        m = ToyLinearModel().eval().to(torch.bfloat16)
        m_copy = copy.deepcopy(m)
        example_inputs = m.example_inputs(dtype=torch.bfloat16)
        quantize_(m, int8_weight_only())
        ref = m(*example_inputs)
        with tempfile.NamedTemporaryFile() as f:
            torch.save(m.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
        # assign=True keeps the loaded (quantized) tensors instead of copying them into
        # m_copy's existing float params (see nn.Module.load_state_dict docs).
        m_copy.load_state_dict(state_dict, assign=True)
        res = m_copy(*example_inputs)
        self.assertEqual(res, ref)
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_int8wo_quantized_model_to_device(self):
        """An int8 weight-only model quantized on CPU must give the same outputs after ``.to("cuda")``."""
        m = ToyLinearModel().eval().to(torch.bfloat16)
        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu")
        quantize_(m, int8_weight_only())
        ref = m(*example_inputs)
        example_inputs_cuda = (example_inputs[0].to("cuda"),)
        m.to(device="cuda")
        cuda_res = m(*example_inputs_cuda)
        self.assertEqual(cuda_res.cpu(), ref)
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+")
    def test_int4wo_quantized_model_to_device(self):
        """An int4 weight-only model must keep its outputs after ``.to(device)``, for "cuda" and "cuda:0"."""
        # TODO: change initial model to "cpu"
        devices = ["cuda", "cuda:0"]
        for device in devices:
            # NOTE(review): the model, inputs and .to() target are all the same device, so
            # this currently only checks that .to(<current device>) preserves outputs.
            m = ToyLinearModel().eval().to(torch.bfloat16).to(device)
            example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device)
            quantize_(m, int4_weight_only())
            ref = m(*example_inputs)
            example_inputs_cuda = (example_inputs[0].to(device),)
            m.to(device=device)
            cuda_res = m(*example_inputs_cuda)
            self.assertEqual(cuda_res.cpu(), ref)
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_quantized_tensor_subclass_save_load_map_location(self):
        """Save an int8 weight-only CUDA model, then reload it and compare outputs.

        The state_dict is loaded with ``map_location="cpu"`` and ``mmap=True`` into
        a meta-device model (``assign=True``). That model is then moved to CUDA,
        and its outputs must match.
        """
        m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device="cuda")
        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
        quantize_(m, int8_weight_only())
        ref = m(*example_inputs)
        with tempfile.NamedTemporaryFile() as f:
            torch.save(m.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f.name, map_location="cpu", mmap=True)
        # Build the target on the meta device; assign=True below supplies the real tensors.
        with torch.device("meta"):
            m_copy = ToyLinearModel().eval()
        m_copy.load_state_dict(state_dict, assign=True)
        m_copy.to(dtype=torch.bfloat16, device="cuda")
        res = m_copy(*example_inputs)
        self.assertEqual(res, ref)
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_quantized_model_streaming(self):
        """Streaming quantization (``quantize_(..., device="cuda")``) must lower peak CUDA memory.

        The peak is compared against moving the whole float model to CUDA first.
        Streaming must also leave every parameter on CUDA.
        """
        # Free cached blocks and reset the peak counter so each measurement is independent.
        def reset_memory():
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
        reset_memory()
        # Baseline: move the full float model to CUDA, then quantize it there.
        m = ToyLinearModel()
        quantize_(m.to(device="cuda"), int8_weight_only())
        memory_baseline = torch.cuda.max_memory_allocated()
        del m
        reset_memory()
        # Streaming: quantize_ handles the move to CUDA itself via `device=`.
        m = ToyLinearModel()
        quantize_(m, int8_weight_only(), device="cuda")
        memory_streaming = torch.cuda.max_memory_allocated()
        for param in m.parameters():
            assert param.is_cuda
        self.assertLess(memory_streaming, memory_baseline)
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    def test_int4wo_cpu(self, dtype, x_dim):
        """With Int4CPULayout, compiled code must use ``_weight_int4pack_mm_for_cpu`` instead of ``aten.mm.default``.

        Checked for 2-D and 3-D inputs across float/bf16/fp16.
        """
        from torchao.dtypes import Int4CPULayout
        device = "cpu"
        m = ToyLinearModel().eval().to(dtype).to(device)
        example_inputs = m.example_inputs(dtype=dtype, device=device)
        if x_dim == 3:
            # Add a leading dim to exercise the 3-D input path.
            example_inputs = (example_inputs[0].unsqueeze(0),)
        with torch.no_grad():
            quantize_(m, int4_weight_only(group_size=32, layout=Int4CPULayout()))
            # ensure the expected op is in the code
            _, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            assert "_weight_int4pack_mm_for_cpu" in code[0]
            assert "aten.mm.default" not in code[0]
class TestMultiTensorFlow(TestCase):
    """Unit tests for the ``MultiTensor`` container from ``torchao.quantization.GPTQ_MT``."""

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_multitensor_add_tensors(self):
        """``add_tensors`` appends a second tensor: ``count`` becomes 2 and ``values`` keeps both, in order."""
        from torchao.quantization.GPTQ_MT import MultiTensor

        first, second = torch.randn(3, 3), torch.randn(3, 3)
        multi = MultiTensor(first)
        multi.add_tensors(second)
        self.assertEqual(multi.count, 2)
        for idx, expected in enumerate((first, second)):
            self.assertTrue(torch.equal(multi.values[idx], expected))

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_multitensor_pad_unpad(self):
        """``pad_to_length(3)`` raises ``count`` to 3, and ``unpad`` brings it back to 1."""
        from torchao.quantization.GPTQ_MT import MultiTensor

        multi = MultiTensor(torch.randn(3, 3))
        multi.pad_to_length(3)
        self.assertEqual(multi.count, 3)
        multi.unpad()
        self.assertEqual(multi.count, 1)

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_multitensor_inplace_operation(self):
        """``+= 1`` on a MultiTensor of ones leaves its first value equal to a 3x3 tensor of 2s."""
        from torchao.quantization.GPTQ_MT import MultiTensor

        multi = MultiTensor(torch.ones(3, 3))
        multi += 1  # in-place addition
        expected = torch.full((3, 3), 2)
        self.assertTrue(torch.equal(multi.values[0], expected))
# Materialize the @common_utils.parametrize variants declared in TestQuantFlow
# (TestMultiTensorFlow has no parametrized tests, so it needs no such call).
common_utils.instantiate_parametrized_tests(TestQuantFlow)
# Allow running this file directly as well as through a test runner.
if __name__ == "__main__":
    unittest.main()