diff --git a/test/test_ops_rowwise_scaled_linear_cutlass.py b/test/test_ops_rowwise_scaled_linear_cutlass.py
index f9b9c6a7f9..72bb201b3f 100644
--- a/test/test_ops_rowwise_scaled_linear_cutlass.py
+++ b/test/test_ops_rowwise_scaled_linear_cutlass.py
@@ -16,6 +16,7 @@
     _int4_symm_cutlass_quant,
     _int8_symm_cutlass_quant,
 )
+from torchao.testing.utils import get_compute_capability
 
 DTYPES = [torch.float16, torch.bfloat16]
 BATCH_SIZE = [1, 4, 8, 16, 32, 64]
@@ -87,6 +88,7 @@ def run_test_for_op(op, dtype, batch_size, size_mnk, use_bias):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(get_compute_capability() != 8.0, reason="Only supported on A100")
 @pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS)
 def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias):
     run_test_for_op(
@@ -99,6 +101,7 @@ def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bia
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(get_compute_capability() != 8.0, reason="Only supported on A100")
 @pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS)
 def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias):
     run_test_for_op(
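
For context, a minimal sketch of what the imported helper is assumed to do: the skip markers compare `get_compute_capability()` against the float 8.0 (A100, i.e. sm_80), so the helper presumably folds the device's CUDA compute capability into a single float. This is only an illustration built on the public PyTorch API, not the actual torchao.testing.utils implementation:

```python
# Hedged sketch, not the torchao implementation: assumes the helper reports the
# current CUDA device's compute capability as a float (e.g. 8.0 on A100).
import torch


def get_compute_capability() -> float:
    if not torch.cuda.is_available():
        return 0.0
    # torch.cuda.get_device_capability() returns a (major, minor) tuple,
    # e.g. (8, 0) on A100; fold it into a single float for easy comparison.
    major, minor = torch.cuda.get_device_capability()
    return float(f"{major}.{minor}")
```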