 ACT_ORDER_OPTS = [False, True]
 K_FULL_OPTS = [False, True]
+USE_ATOMIC_ADD_OPTS = [False, True]
 USE_FP32_REDUCE_OPTS = [False, True]
 
 MARLIN_K_CHUNKS = [128]
@@ -194,6 +195,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
 @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
 @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
 @pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
+@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
 @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
 def test_gptq_marlin_gemm(
     k_chunk,
@@ -203,6 +205,7 @@ def test_gptq_marlin_gemm(
     mnk_factors,
     act_order,
     is_k_full,
+    use_atomic_add,
     use_fp32_reduce,
 ):
     m_factor, n_factor, k_factor = mnk_factors
@@ -232,7 +235,8 @@ def test_gptq_marlin_gemm(
         torch.ops._C.gptq_marlin_gemm,
         (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
          workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
-         a_input.shape[1], is_k_full, False, use_fp32_reduce, False),
+         a_input.shape[1], is_k_full, False, use_atomic_add, use_fp32_reduce,
+         False),
         test_utils=DEFAULT_OPCHECK_TEST_UTILS)
 
     output = ops.gptq_marlin_gemm(
@@ -249,6 +253,7 @@ def test_gptq_marlin_gemm(
         a_input.shape[1],
         is_k_full=is_k_full,
         has_zp=False,
+        use_atomic_add=use_atomic_add,
         use_fp32_reduce=use_fp32_reduce,
         is_zp_float=False,
     )
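The added USE_ATOMIC_ADD_OPTS axis doubles the parametrized matrix of test_gptq_marlin_gemm, since stacked pytest.mark.parametrize decorators expand to the cross product of their option lists. A minimal standalone sketch of that mechanism (the test name and second option list below are hypothetical stand-ins, not part of this patch):

import pytest

# Two boolean axes: stacked parametrize decorators generate one test case
# per (use_atomic_add, use_fp32_reduce) combination, i.e. 2 x 2 = 4 cases.
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [False, True]

@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_flag_combinations(use_atomic_add, use_fp32_reduce):
    # Each combination runs as its own test case; here we only check types.
    assert isinstance(use_atomic_add, bool)
    assert isinstance(use_fp32_reduce, bool)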