@@ -2446,10 +2446,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 
     // TODO: add AVX / WASM SIMD / etc
 #if defined(__ARM_NEON)
-    float sum00 = 0.0f;
-    float sum01 = 0.0f;
-    float sum10 = 0.0f;
-    float sum11 = 0.0f;
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
     for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
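This first hunk swaps the four scalar accumulators for two float32x4_t vector accumulators, one per block of the unrolled pair. Partial sums now stay in NEON registers for the whole loop and are reduced to a scalar only once after it, instead of paying for a horizontal reduction on every iteration. A minimal sketch of that accumulation pattern, with illustrative names (only the intrinsics match the diff):

```c
#include <arm_neon.h>

// Keep four partial sums in one vector register; reduce once at the end.
static float accumulate_scaled(const int32x4_t *p, const float *scale, int nb) {
    float32x4_t sumv = vdupq_n_f32(0.0f);
    for (int i = 0; i < nb; ++i) {
        // vmlaq_n_f32: sumv += (float)p[i] * scale[i], lane by lane
        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(p[i]), scale[i]);
    }
    return vaddvq_f32(sumv); // single horizontal add (AArch64)
}
```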
@@ -2480,20 +2478,24 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
-        // Note: cannot use vaddvq_s8 because it overflows for 8-bit values
-        // TODO: is there a better way to do this?
-        sum00 += (x0->m*y0->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_0ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0ls))) +
-                                vaddvq_s16(vmovl_s8(vget_low_s8(v1_0hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0hs))));
-        sum01 += (x1->m*y1->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_1ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1ls))) +
-                                vaddvq_s16(vmovl_s8(vget_low_s8(v1_1hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1hs))));
+        const int16x8_t s0i = vaddq_s16(
+                vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
+                vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
+
+        const int16x8_t s1i = vaddq_s16(
+                vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
+                vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
 
-        sum10 += (x0->d*y0->d)*vaddvq_s32(p_0);
-        sum11 += (x1->d*y1->d)*vaddvq_s32(p_1);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
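This hunk touches both terms of the q4_1 dot product. A q4_1 block stores x[j] ≈ d·qx[j] + m and a q8_0 block stores y[j] = yd·qy[j], so per block Σ x[j]·y[j] = (d·yd)·Σ qx[j]·qy[j] + (m·yd)·Σ qy[j]. The new s0i/s1i compute the Σ qy factor by widening the int8 values to int16 before adding, sidestepping the vaddvq_s8 overflow the deleted comment warned about; the Σ qx·qy factor follows, via vdotq_s32 where ARMv8.2 dot-product instructions are available. A scalar reference for one block, as a sketch (the struct fields and the interleaved nibble layout are assumptions about that era's block_q4_1/block_q8_0):

```c
#include <stdint.h>

#define QK 32  // assumed block size (shared by q4_1 and q8_0 at the time)

typedef struct { float d, m; uint8_t qs[QK/2]; } block_q4_1; // sketched layout
typedef struct { float d;    int8_t  qs[QK];   } block_q8_0;

// Scalar reference for one block: (d*yd)*sum(qx*qy) + (m*yd)*sum(qy).
static float block_dot_ref(const block_q4_1 *x, const block_q8_0 *y) {
    int sum_qq = 0; // sum(qx*qy): the d*yd term (p_0/p_1 in the diff)
    int sum_q  = 0; // sum(qy):    the m*yd term (s0i/s1i in the diff)
    for (int j = 0; j < QK/2; ++j) {
        const int qx0 = x->qs[j] & 0x0F; // low nibble: element 2j (assumed)
        const int qx1 = x->qs[j] >> 4;   // high nibble: element 2j + 1
        sum_qq += qx0*y->qs[2*j + 0] + qx1*y->qs[2*j + 1];
        sum_q  += y->qs[2*j + 0] + y->qs[2*j + 1];
    }
    return (x->d * y->d)*sum_qq + (x->m * y->d)*sum_q;
}
```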
@@ -2505,21 +2507,17 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
         const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
 
-        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
-        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
-
-        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
-        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
-
-        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
-        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sum10 += x0->d*y0->d*vaddvq_s16(p_0);
-        sum11 += x1->d*y1->d*vaddvq_s16(p_1);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
 #endif
     }
 
-    sumf = sum00 + sum01 + sum10 + sum11;
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
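In the non-DOTPROD fallback the old code folded all the int16 products into a single int16x8_t and reduced it with vaddvq_s16; with 32 products behind that final horizontal sum, the result can exceed the int16 range. The rewrite pairwise-widens each product vector to int32 with vpaddlq_s16 before any further addition, then feeds the int32x4_t into the same vmlaq_n_f32 accumulation as the DOTPROD path. The reduction step in isolation, as a sketch:

```c
#include <arm_neon.h>

// Sum 16 int16 products into an int32 without ever adding int16 lanes
// into an int16 result: vpaddlq_s16 widens to int32 as it pairwise-adds.
static int32_t sum_products(int16x8_t lo, int16x8_t hi) {
    const int32x4_t s = vaddq_s32(vpaddlq_s16(lo), vpaddlq_s16(hi));
    return vaddvq_s32(s); // horizontal add of four int32 lanes
}
```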