Skip to content

Commit e78d93f

Browse files
committed
chore: update test
1 parent a3bfd80 commit e78d93f

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

test/sparsity/test_marlin.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from torch import nn
66
from torch.testing._internal.common_utils import TestCase, run_tests
7-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
7+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass
88
from torchao.dtypes import MarlinSparseLayoutType
99
from torchao.sparsity.sparse_api import apply_fake_sparsity
1010
from torchao.quantization.quant_api import int4_weight_only, quantize_
@@ -55,7 +55,6 @@ def test_quant_sparse_marlin_layout_eager(self):
5555

5656
assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
5757

58-
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5")
5958
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
6059
def test_quant_sparse_marlin_layout_compile(self):
6160
apply_fake_sparsity(self.model)
@@ -68,6 +67,9 @@ def test_quant_sparse_marlin_layout_compile(self):
6867

6968
# Sparse + quantized
7069
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
70+
if not TORCH_VERSION_AT_LEAST_2_5:
71+
unwrap_tensor_subclass(self.model)
72+
7173
self.model.forward = torch.compile(self.model.forward, fullgraph=True)
7274
sparse_result = self.model(self.input)
7375

0 commit comments

Comments
 (0)