@@ -25,6 +25,7 @@
     to_affine_quantized_intx_static,
 )
 from torchao.quantization import (
+    GemliteUIntXWeightOnlyConfig,
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
     float8_weight_only,
@@ -36,7 +37,7 @@
     quantize_,
 )
 from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
-from torchao.testing.utils import skip_if_no_cuda, skip_if_rocm
+from torchao.testing.utils import skip_if_no_cuda, skip_if_no_gemlite, skip_if_rocm
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     check_cpu_version,
@@ -176,7 +177,7 @@ def _apply(module, config_or_subclass_inserter):

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_register_new_dispatch(self):
-        from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
+        from torchao.dtypes import AffineQuantizedTensor
         from torchao.dtypes.affine_quantized_tensor_ops import (
             deregister_aqt_quantized_linear_dispatch,
             register_aqt_quantized_linear_dispatch,
@@ -344,7 +345,7 @@ def test_alias(self, device, dtype):
     @common_utils.parametrize("device", ["cuda"])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     @skip_if_no_cuda()
-    def test_slice(self, device, dtype):
+    def test_slice_int4wo(self, device, dtype):
         # in_feature not divisible by 1024
         # out_feature not divisible by 8
         # to test slice + padding for int4 weight only quantization
@@ -354,6 +355,20 @@ def test_slice(self, device, dtype):
         _ = dummy.weight.narrow(0, 0, 64)
         _ = dummy.weight.narrow(1, 0, 128)

+    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("dtype", [torch.float16, torch.bfloat16])
+    @skip_if_no_cuda()
+    @skip_if_no_gemlite()
+    def test_slice_gemlite(self, device, dtype):
+        # in_feature not divisible by 1024
+        # out_feature not divisible by 8
+        # to test slice + padding for gemlite weight-only quantization
+        dummy = nn.Linear(256, 512, dtype=dtype, device=device)
+        quantize_(dummy, GemliteUIntXWeightOnlyConfig())
+        # make sure these run without error
+        _ = dummy.weight.narrow(0, 0, 64)
+        _ = dummy.weight.narrow(1, 0, 128)
+
     @common_utils.parametrize("device", ["cuda"])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     def test_matmul(self, device, dtype):
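For reference, the pattern the new test exercises can be used directly outside the test suite. Below is a minimal sketch, assuming a CUDA device and the optional gemlite package are installed; the layer shapes are illustrative and mirror the test.

import torch
import torch.nn as nn

from torchao.quantization import GemliteUIntXWeightOnlyConfig, quantize_

# Quantize the linear layer's weight in place with the gemlite uintx
# weight-only config (requires CUDA and the gemlite package).
linear = nn.Linear(256, 512, dtype=torch.float16, device="cuda")
quantize_(linear, GemliteUIntXWeightOnlyConfig())

# Slicing the packed weight along either dimension should run without
# error, which is what test_slice_gemlite checks.
_ = linear.weight.narrow(0, 0, 64)   # slice along out_features
_ = linear.weight.narrow(1, 0, 128)  # slice along in_features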