@@ -1791,11 +1791,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             const int8x16_t y1_l = vld1q_s8(b_y1->qs);
             const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

-            float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                    GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
-                                    GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                    GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
-
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
             float32x4_t scale = vld1q_f32(_scale);

             int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -1811,7 +1812,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));

             sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
-                                l1, r1)), l2, r2)), l3, r3))), scale);
+                                                                                       l1, r1)), l2, r2)), l3, r3))), scale);
         }

         float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
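Note on the _scale ordering above: each vmmlaq_s32 call treats its two int8x16_t operands as 2x8 row-major matrices and accumulates the 2x2 matrix of row dot products, laid out row-major in the int32x4_t result, so the scales are ordered (x0,y0), (x0,y1), (x1,y0), (x1,y1) to line up with that layout. A minimal, self-contained sketch of the lane layout (illustration only, not part of the patch; assumes an i8mm-capable compiler and CPU, e.g. -march=armv8.2-a+i8mm):

    // Standalone illustration (not part of the patch) of the vmmlaq_s32 lane layout.
    #include <arm_neon.h>
    #include <stdio.h>

    int main(void) {
        int8_t a[16], b[16];
        for (int i = 0; i < 16; ++i) { a[i] = (int8_t)(i + 1); b[i] = 1; }

        // Each int8x16_t is read as a 2x8 row-major matrix; result lane j holds
        // dot(row j/2 of a, row j%2 of b), i.e. a row-major 2x2 block.
        int32x4_t acc = vmmlaq_s32(vdupq_n_s32(0), vld1q_s8(a), vld1q_s8(b));

        int32_t out[4];
        vst1q_s32(out, acc);
        printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);  // 36 36 100 100
        return 0;
    }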
@@ -2347,10 +2348,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
             const block_q8_1 * restrict b_y0 = &vy0[i];
             const block_q8_1 * restrict b_y1 = &vy1[i];

-            float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
-                                    GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
-                                    GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
-                                    GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
+            float32_t summs_t[4] = {
+                GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
+                GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
+                GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
+                GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)
+            };
             summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));

             const uint8x16_t m4b = vdupq_n_u8(0x0F);
@@ -2371,10 +2374,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
             const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

             // mmla into int32x4_t
-            float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
-                                   GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
-                                   GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
-                                   GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
             float32x4_t scale = vld1q_f32(_scale);

             int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -2389,15 +2394,17 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
             int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
             int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
             sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
-                                l1, r1)), l2, r2)), l3, r3))), scale);
+                                                                                       l1, r1)), l2, r2)), l3, r3))), scale);
         }

-            float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
         float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
         sumv2 = vaddq_f32(sumv2, summs0);

         vst1_f32(s, vget_low_f32(sumv2));
         vst1_f32(s + bs, vget_high_f32(sumv2));
+
         return;
     }
 #endif
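Besides the reformatting, the hunk above also wraps b_y0->d and b_y1->d in GGML_FP16_TO_FP32 before multiplying, consistent with the other kernels (the q8_1 block scale is stored as fp16 here). For context, a hedged scalar sketch of the per-block math this q4_1 x q8_1 path vectorizes; the struct fields and QK8_1 follow the usual ggml definitions and are assumptions, not part of the patch:

    // Scalar reference (sketch only, not from the patch): one q4_1 block against
    // one q8_1 block. q4_1 stores x = d*q + m with 4-bit q, q8_1 stores y = d*q
    // with 8-bit q, and y->s is assumed to hold d_y * sum(q_y), so the bias term
    // is m_x * s_y (the summs_t entries above) and the integer dot product is
    // scaled by d_x * d_y.
    static float q4_1_q8_1_block_dot(const block_q4_1 * restrict x,
                                     const block_q8_1 * restrict y) {
        int32_t sumi = 0;
        for (int j = 0; j < QK8_1/2; ++j) {
            const int v0 = x->qs[j] & 0x0F;  // low nibble  -> element j
            const int v1 = x->qs[j] >> 4;    // high nibble -> element j + QK8_1/2
            sumi += v0 * y->qs[j] + v1 * y->qs[j + QK8_1/2];
        }
        return GGML_FP16_TO_FP32(x->d) * GGML_FP16_TO_FP32(y->d) * sumi
             + GGML_FP16_TO_FP32(x->m) * GGML_FP16_TO_FP32(y->s);
    }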
@@ -3374,10 +3381,12 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             const int8x16_t y1_l = vld1q_s8(b_y1->qs);
             const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

-            float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                   GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
-                                   GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                   GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
             float32x4_t scale = vld1q_f32(_scale);

             int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3393,13 +3402,15 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));

             sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
-                                l1, r1)), l2, r2)), l3, r3))), scale);
+                                                                                       l1, r1)), l2, r2)), l3, r3))), scale);
         }
-            float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+
+        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
         float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

-            vst1_f32(s, vget_low_f32(sumv2));
+        vst1_f32(s, vget_low_f32(sumv2));
         vst1_f32(s + bs, vget_high_f32(sumv2));
+
         return;
     }
 #endif
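Finally, the tail of each of these nrc == 2 paths unpacks the row-major 2x2 accumulator into two output rows: sumv0 holds the scaled dot products in the order (x0,y0), (x0,y1), (x1,y0), (x1,y1), and the vextq_f32/vzip1q_f32 pair regroups them so that the (.,y0) pair is stored at s and the (.,y1) pair at s + bs. A small self-contained sketch of that shuffle (illustration only, not part of the patch):

    // Standalone illustration (not from the patch) of the final lane shuffle.
    // sumv0 holds the four dot products row-major: {d00, d01, d10, d11}, where
    // dij = dot(x_i, y_j); the ext/zip pair regroups them so that {d00, d10}
    // is written to s and {d01, d11} to s + bs.
    #include <arm_neon.h>
    #include <stdio.h>

    int main(void) {
        const float in[4] = { 1.0f /* d00 */, 2.0f /* d01 */, 3.0f /* d10 */, 4.0f /* d11 */ };
        float32x4_t sumv0 = vld1q_f32(in);

        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);  // {3, 4, 1, 2}
        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);    // {1, 3, 2, 4}

        float row0[2], row1[2];
        vst1_f32(row0, vget_low_f32(sumv2));   // what the kernel stores at s
        vst1_f32(row1, vget_high_f32(sumv2));  // what the kernel stores at s + bs
        printf("row0: %g %g   row1: %g %g\n", row0[0], row0[1], row1[0], row1[1]);
        return 0;
    }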