16
16
_int4_symm_cutlass_quant ,
17
17
_int8_symm_cutlass_quant ,
18
18
)
19
+ from torchao .testing .utils import get_compute_capability
19
20
20
21
DTYPES = [torch .float16 , torch .bfloat16 ]
21
22
BATCH_SIZE = [1 , 4 , 8 , 16 , 32 , 64 ]
@@ -87,6 +88,7 @@ def run_test_for_op(op, dtype, batch_size, size_mnk, use_bias):
87
88
88
89
89
90
@pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA not available" )
91
+ @pytest .mark .skipif (get_compute_capability () != 8.0 , reason = "Only supported on A100" )
90
92
@pytest .mark .parametrize ("dtype, batch_size, size_mnk, use_bias" , TEST_PARAMS )
91
93
def test_rowwise_scaled_linear_cutlass_s4s4 (dtype , batch_size , size_mnk , use_bias ):
92
94
run_test_for_op (
@@ -99,6 +101,7 @@ def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bia
99
101
100
102
101
103
@pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA not available" )
104
+ @pytest .mark .skipif (get_compute_capability () != 8.0 , reason = "Only supported on A100" )
102
105
@pytest .mark .parametrize ("dtype, batch_size, size_mnk, use_bias" , TEST_PARAMS )
103
106
def test_rowwise_scaled_linear_cutlass_s8s4 (dtype , batch_size , size_mnk , use_bias ):
104
107
run_test_for_op (
0 commit comments