Skip to content

Commit

Permalink
chore: update test
Browse files Browse the repository at this point in the history
  • Loading branch information
Diogo-V committed Sep 5, 2024
1 parent a3bfd80 commit e78d93f
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions test/sparsity/test_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass
from torchao.dtypes import MarlinSparseLayoutType
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.quantization.quant_api import int4_weight_only, quantize_
Expand Down Expand Up @@ -55,7 +55,6 @@ def test_quant_sparse_marlin_layout_eager(self):

assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_quant_sparse_marlin_layout_compile(self):
apply_fake_sparsity(self.model)
Expand All @@ -68,6 +67,9 @@ def test_quant_sparse_marlin_layout_compile(self):

# Sparse + quantized
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(self.model)

self.model.forward = torch.compile(self.model.forward, fullgraph=True)
sparse_result = self.model(self.input)

Expand Down

0 comments on commit e78d93f

Please sign in to comment.