Skip to content

Commit 5a78b70

Browse files
authored
Skip failing tests for rowwise-scaled (#2022)
stack-info: PR: #2022, branch: drisspg/stack/46
1 parent 6922733 commit 5a78b70

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

test/test_ops_rowwise_scaled_linear_cutlass.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
_int4_symm_cutlass_quant,
1717
_int8_symm_cutlass_quant,
1818
)
19+
from torchao.testing.utils import get_compute_capability
1920

2021
DTYPES = [torch.float16, torch.bfloat16]
2122
BATCH_SIZE = [1, 4, 8, 16, 32, 64]
@@ -87,6 +88,7 @@ def run_test_for_op(op, dtype, batch_size, size_mnk, use_bias):
8788

8889

8990
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
91+
@pytest.mark.skipif(get_compute_capability() != 8.0, reason="Only supported on A100")
9092
@pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS)
9193
def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias):
9294
run_test_for_op(
@@ -99,6 +101,7 @@ def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bia
99101

100102

101103
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
104+
@pytest.mark.skipif(get_compute_capability() != 8.0, reason="Only supported on A100")
102105
@pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS)
103106
def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias):
104107
run_test_for_op(

0 commit comments

Comments
 (0)