
Commit 1458451

update test

Signed-off-by: Jinzhen Lin <[email protected]>
1 parent b3636a7, commit 1458451

1 file changed: 6 additions, 1 deletion

1 file changed

+6
-1
lines changed

tests/kernels/test_marlin_gemm.py

Lines changed: 6 additions & 1 deletion
@@ -34,6 +34,7 @@
 
 ACT_ORDER_OPTS = [False, True]
 K_FULL_OPTS = [False, True]
+USE_ATOMIC_ADD_OPTS = [False, True]
 USE_FP32_REDUCE_OPTS = [False, True]
 
 MARLIN_K_CHUNKS = [128]
@@ -194,6 +195,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
 @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
 @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
 @pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
+@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
 @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
 def test_gptq_marlin_gemm(
     k_chunk,
@@ -203,6 +205,7 @@ def test_gptq_marlin_gemm(
     mnk_factors,
     act_order,
     is_k_full,
+    use_atomic_add,
     use_fp32_reduce,
 ):
     m_factor, n_factor, k_factor = mnk_factors
@@ -232,7 +235,8 @@ def test_gptq_marlin_gemm(
         torch.ops._C.gptq_marlin_gemm,
         (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
          workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
-         a_input.shape[1], is_k_full, False, use_fp32_reduce, False),
+         a_input.shape[1], is_k_full, False, use_atomic_add, use_fp32_reduce,
+         False),
         test_utils=DEFAULT_OPCHECK_TEST_UTILS)
 
     output = ops.gptq_marlin_gemm(
@@ -249,6 +253,7 @@ def test_gptq_marlin_gemm(
         a_input.shape[1],
         is_k_full=is_k_full,
         has_zp=False,
+        use_atomic_add=use_atomic_add,
         use_fp32_reduce=use_fp32_reduce,
         is_zp_float=False,
     )
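
The new flag is wired in with the same pattern the test already uses for act_order and is_k_full: a module-level option list plus a stacked @pytest.mark.parametrize decorator, so use_atomic_add is swept against every existing combination. In the opcheck call the flag is passed positionally just before use_fp32_reduce, while the ops.gptq_marlin_gemm call passes it by keyword. Below is a minimal, self-contained sketch of that parametrization pattern; fake_gemm and its arguments are hypothetical stand-ins for illustration, not the vLLM API.

# Minimal sketch of the stacked-parametrize pattern used in the diff.
# `fake_gemm` is a hypothetical stand-in; the real test calls
# ops.gptq_marlin_gemm and checks the kernel against a reference GEMM.
import pytest
import torch

USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [False, True]


def fake_gemm(a, b, *, use_atomic_add: bool, use_fp32_reduce: bool):
    # Stand-in kernel: in the real kernel these flags only select internal
    # reduction strategies, so the numerical result should not depend on them.
    return a @ b


@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_fake_gemm(use_atomic_add, use_fp32_reduce):
    # Each combination of the two flags runs as a separate test case.
    a = torch.randn(16, 32)
    b = torch.randn(32, 8)
    out = fake_gemm(a, b, use_atomic_add=use_atomic_add,
                    use_fp32_reduce=use_fp32_reduce)
    torch.testing.assert_close(out, a @ b)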
