Commit 7a766df (1 parent: 4271183)

[experimental][kleidi] Reduce template types for tests

2 files changed: +24 additions, -61 deletions
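
The change itself is mechanical: both Kleidi test helpers previously took `weight_nbit` and `has_weight_zeros` as template parameters, but every test instantiated them with `4` and `false`, so the commit pins those values inside the helpers and drops them from the template signatures. A minimal standalone sketch of the pattern follows; the driver function and helper names here are hypothetical stand-ins for illustration, not the real torchao API:

```cpp
#include <cstdio>

// Hypothetical stand-in for the real test driver in test_utils.h.
void run_kernel_test(int m, int k, int n, int group_size, int weight_nbit,
                     bool has_weight_zeros, bool has_bias, bool has_clamp) {
  std::printf("m=%d k=%d n=%d gs=%d nbit=%d zeros=%d bias=%d clamp=%d\n", m, k,
              n, group_size, weight_nbit, has_weight_zeros, has_bias,
              has_clamp);
}

// After the change: only the parameters the tests actually vary stay in the
// template; weight_nbit and has_weight_zeros are fixed at the call site.
template <bool has_bias, bool has_clamp>
void test_helper(int m, int k, int n, int group_size) {
  run_kernel_test(
      m, k, n, group_size,
      /*weight_nbit=*/4,           // every Kleidi test here uses 4-bit weights
      /*has_weight_zeros=*/false,  // and no weight zero points
      has_bias, has_clamp);
}

int main() {
  // Call sites shrink from four template arguments to two.
  test_helper</*has_bias=*/false, /*has_clamp=*/false>(1, 32, 4, 32);
  test_helper</*has_bias=*/false, /*has_clamp=*/true>(1, 128, 182, 128);
}
```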

torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp (24 additions, 22 deletions)
@@ -357,7 +357,7 @@ TEST(
 // #ifdef TORCHAO_ENABLE_KLEIDI
 // TODO: Wire up the the compile defination for TORCHAO_ENABLE_KLEIDI
 
-template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
+template <bool has_bias, bool has_clamp>
 void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
     int m,
     int k,
@@ -369,8 +369,8 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
       k,
       n,
       group_size,
-      weight_nbit,
-      has_weight_zeros,
+      /*weight_nbit=*/4,
+      /*has_weight_zeros*/false,
       has_bias,
       has_clamp,
       /*weight_scale_bf16_round_trip=*/true);
@@ -421,8 +421,6 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     k_eq_gs_32) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32);
@@ -432,8 +430,6 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     large_k_n_gs32) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32);
@@ -443,8 +439,6 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     even_n_gs32) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32);
@@ -454,14 +448,21 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     k_eq_gs128) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
 }
 
-template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+    clamp_k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
+      false /*has_bias*/,
+      true /*has_clamp*/>(
+      /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
+
+template <bool has_bias, bool has_clamp>
 void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
     int m,
     int k,
@@ -473,8 +474,8 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
       k,
       n,
       group_size,
-      weight_nbit,
-      has_weight_zeros,
+      /*weight_nbit=*/4,
+      /*has_weight_zeros=*/false,
       has_bias,
       has_clamp,
       /*weight_scale_bf16_round_trip=*/true);
@@ -525,8 +526,6 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
     k_eq_gs_32) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32);
@@ -536,8 +535,6 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
     large_k_n_gs32) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32);
@@ -547,8 +544,6 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
     even_n_gs32) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32);
@@ -558,11 +553,18 @@ TEST(
     test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
     k_eq_gs128) {
   test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
-      4 /*weight_nbit*/,
-      false /*has_weight_zeros*/,
       false /*has_bias*/,
       false /*has_clamp*/>(
       /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
 }
+
+TEST(
+    test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
+    clamp_k_eq_gs128) {
+  test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
+      false /*has_bias*/,
+      true /*has_clamp*/>(
+      /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
+}
 // #endif // defined(TORCHAO_ENABLE_KLEIDI)
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
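Besides shrinking the template signatures, the diff above adds one new `clamp_k_eq_gs128` test per kernel variant, instantiated with `has_clamp=true`, so the clamping path of both dotprod kernels is now exercised by these tests.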

torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h (0 additions, 39 deletions)
@@ -235,7 +235,6 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
           zero,
           qmin,
           qmax);
-      // std::fill(weight_qvals.begin(), weight_qvals.end(), -7);
     }
 
     std::vector<float> bias(m, 0.0);
@@ -277,44 +276,6 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case {
       }
     }
 
-#if 0 // Alternate reference implementation for debugging.
-    auto num_groups = k / weight_group_size;
-    for (int m_idx = 0; m_idx < m; m_idx++) {
-      for (int n_idx = 0; n_idx < n; n_idx++) {
-        int32_t result_idx = m_idx * n + n_idx;
-        float weights_fsum = 0.0;
-        for (int g_idx = 0; g_idx < num_groups; g_idx++) {
-          int32_t weights_qsum = 0;
-          int32_t acc_i32 = 0;
-          for (int k_idx = 0; k_idx < weight_group_size; k_idx++) {
-            const int32_t activation_idx = m_idx * k + g_idx * weight_group_size + k_idx;
-            const int32_t weight_idx = n_idx * k + g_idx * weight_group_size + k_idx;
-
-            const int32_t weight_qval = weight_qvals[weight_idx];
-            const int32_t activation_qval = activation_qvals[activation_idx];
-
-            weights_qsum += weight_qval;
-            acc_i32 += weight_qval * activation_qval;
-          }
-          // For each group, we have a weight scale
-          const int32_t weight_scale_idx = n_idx * num_groups + g_idx;
-          const float weight_scale = weight_scales[weight_scale_idx]; // already rounded trip to bf16
-          expected_output[result_idx] += (float) acc_i32 * weight_scales[weight_scale_idx];
-          weights_fsum += weights_qsum * weight_scale;
-        }
-        // For each output channel, we have an activation scale
-        const int32_t activation_zero_point = activation_zeros[m_idx];
-        const float activation_scale = activation_scales[m_idx];
-        expected_output[result_idx] -= activation_zero_point * weights_fsum;
-        expected_output[result_idx] *= activation_scale;
-        expected_output[result_idx] += bias[m_idx];
-        if (has_clamp) {
-          expected_output[result_idx] = std::min(std::max(expected_output[result_idx], clamp_min), clamp_max);
-        }
-      }
-    }
-#endif
-
     // Return test case
     return channelwise_8bit_activation_groupwise_lowbit_weight_test_case(
         m,
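
For reference, the deleted `#if 0` block computed the expected output via the usual expansion of a groupwise-quantized dot product. In the notation of the removed code (per-group weight scales, a per-row activation scale and zero point, and a bias term), it evaluated, per output element, the following; this is my transcription of the removed loop, not a formula from the torchao sources:

\[
y_{m,n} = s_a \Big( \sum_{g} s_{w,g} \sum_{k \in g} w_{q,k}\, a_{q,k}
\;-\; z_a \sum_{g} s_{w,g} \sum_{k \in g} w_{q,k} \Big) + b_m,
\]

optionally clamped to \([\mathrm{clamp\_min}, \mathrm{clamp\_max}]\) when `has_clamp` is set. The active reference implementation above the deleted block remains and computes the same quantity.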
