@@ -357,7 +357,7 @@ TEST(
357
357
// #ifdef TORCHAO_ENABLE_KLEIDI
358
358
// TODO: Wire up the compile definition for TORCHAO_ENABLE_KLEIDI
359
359
360
- template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
360
+ template <bool has_bias, bool has_clamp>
361
361
void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod (
362
362
int m,
363
363
int k,
@@ -369,8 +369,8 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
369
369
k,
370
370
n,
371
371
group_size,
372
- weight_nbit,
373
- has_weight_zeros,
372
+ /* weight_nbit= */ 4 ,
373
+ /* has_weight_zeros*/ false ,
374
374
has_bias,
375
375
has_clamp,
376
376
/* weight_scale_bf16_round_trip=*/ true );
@@ -421,8 +421,6 @@ TEST(
421
421
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
422
422
k_eq_gs_32) {
423
423
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
424
- 4 /* weight_nbit*/ ,
425
- false /* has_weight_zeros*/ ,
426
424
false /* has_bias*/ ,
427
425
false /* has_clamp*/ >(
428
426
/* m=*/ 1 , /* k=*/ 32 , /* n=*/ 4 , /* group_size=*/ 32 );
@@ -432,8 +430,6 @@ TEST(
432
430
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
433
431
large_k_n_gs32) {
434
432
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
435
- 4 /* weight_nbit*/ ,
436
- false /* has_weight_zeros*/ ,
437
433
false /* has_bias*/ ,
438
434
false /* has_clamp*/ >(
439
435
/* m=*/ 1 , /* k=*/ 1024 , /* n=*/ 512 , /* group_size=*/ 32 );
@@ -443,8 +439,6 @@ TEST(
443
439
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
444
440
even_n_gs32) {
445
441
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
446
- 4 /* weight_nbit*/ ,
447
- false /* has_weight_zeros*/ ,
448
442
false /* has_bias*/ ,
449
443
false /* has_clamp*/ >(
450
444
/* m=*/ 1 , /* k=*/ 1024 , /* n=*/ 182 , /* group_size=*/ 32 );
@@ -454,14 +448,21 @@ TEST(
454
448
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
455
449
k_eq_gs128) {
456
450
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
457
- 4 /* weight_nbit*/ ,
458
- false /* has_weight_zeros*/ ,
459
451
false /* has_bias*/ ,
460
452
false /* has_clamp*/ >(
461
453
/* m=*/ 1 , /* k=*/ 128 , /* n=*/ 182 , /* group_size=*/ 128 );
462
454
}
463
455
464
- template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
456
+ TEST (
457
+ test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
458
+ clamp_k_eq_gs128) {
459
+ test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod<
460
+ false /* has_bias*/ ,
461
+ true /* has_clamp*/ >(
462
+ /* m=*/ 1 , /* k=*/ 128 , /* n=*/ 182 , /* group_size=*/ 128 );
463
+ }
464
+
465
+ template <bool has_bias, bool has_clamp>
465
466
void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod (
466
467
int m,
467
468
int k,
@@ -473,8 +474,8 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
473
474
k,
474
475
n,
475
476
group_size,
476
- weight_nbit,
477
- has_weight_zeros,
477
+ /* weight_nbit= */ 4 ,
478
+ /* has_weight_zeros= */ false ,
478
479
has_bias,
479
480
has_clamp,
480
481
/* weight_scale_bf16_round_trip=*/ true );
@@ -525,8 +526,6 @@ TEST(
525
526
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
526
527
k_eq_gs_32) {
527
528
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
528
- 4 /* weight_nbit*/ ,
529
- false /* has_weight_zeros*/ ,
530
529
false /* has_bias*/ ,
531
530
false /* has_clamp*/ >(
532
531
/* m=*/ 1 , /* k=*/ 32 , /* n=*/ 4 , /* group_size=*/ 32 );
@@ -536,8 +535,6 @@ TEST(
536
535
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
537
536
large_k_n_gs32) {
538
537
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
539
- 4 /* weight_nbit*/ ,
540
- false /* has_weight_zeros*/ ,
541
538
false /* has_bias*/ ,
542
539
false /* has_clamp*/ >(
543
540
/* m=*/ 1 , /* k=*/ 1024 , /* n=*/ 512 , /* group_size=*/ 32 );
@@ -547,8 +544,6 @@ TEST(
547
544
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
548
545
even_n_gs32) {
549
546
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
550
- 4 /* weight_nbit*/ ,
551
- false /* has_weight_zeros*/ ,
552
547
false /* has_bias*/ ,
553
548
false /* has_clamp*/ >(
554
549
/* m=*/ 1 , /* k=*/ 1024 , /* n=*/ 182 , /* group_size=*/ 32 );
@@ -558,11 +553,18 @@ TEST(
558
553
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
559
554
k_eq_gs128) {
560
555
test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
561
- 4 /* weight_nbit*/ ,
562
- false /* has_weight_zeros*/ ,
563
556
false /* has_bias*/ ,
564
557
false /* has_clamp*/ >(
565
558
/* m=*/ 1 , /* k=*/ 128 , /* n=*/ 182 , /* group_size=*/ 128 );
566
559
}
560
+
561
+ TEST (
562
+ test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
563
+ clamp_k_eq_gs128) {
564
+ test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod<
565
+ false /* has_bias*/ ,
566
+ true /* has_clamp*/ >(
567
+ /* m=*/ 1 , /* k=*/ 128 , /* n=*/ 182 , /* group_size=*/ 128 );
568
+ }
567
569
// #endif // defined(TORCHAO_ENABLE_KLEIDI)
568
570
#endif // defined(__aarch64__) || defined(__ARM_NEON)
0 commit comments