@@ -1060,7 +1060,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
10601060 const v128_t v = wasm_f32x4_mul (srcv [l ], wasm_f32x4_splat (id ));
10611061 const v128_t vf = wasm_f32x4_add (v , wasm_f32x4_splat (8.5f ));
10621062 const v128_t vi = wasm_i32x4_trunc_sat_f32x4 (vf );
1063- const v128_t vc = wasm_i32x4_min (vi , wasm_i32x4_splat (15 ));
1063+ const v128_t vc = wasm_i32x4_min_u (vi , wasm_i32x4_splat (15 ));
10641064
10651065 y [i ].qs [2 * l + 0 ] = wasm_i32x4_extract_lane (vc , 0 ) | (wasm_i32x4_extract_lane (vc , 1 ) << 4 );
10661066 y [i ].qs [2 * l + 1 ] = wasm_i32x4_extract_lane (vc , 2 ) | (wasm_i32x4_extract_lane (vc , 3 ) << 4 );
@@ -2658,35 +2658,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
26582658 const int8x16_t v0_1ls = vsubq_s8 (v0_1l , s8b );
26592659 const int8x16_t v0_1hs = vsubq_s8 (v0_1h , s8b );
26602660
2661+ // interleave
2662+ const int8x16_t v0_0lz = vzip1q_s8 (v0_0ls , v0_0hs );
2663+ const int8x16_t v0_0hz = vzip2q_s8 (v0_0ls , v0_0hs );
2664+ const int8x16_t v0_1lz = vzip1q_s8 (v0_1ls , v0_1hs );
2665+ const int8x16_t v0_1hz = vzip2q_s8 (v0_1ls , v0_1hs );
2666+
26612667 // load y
26622668 const int8x16_t v1_0l = vld1q_s8 (y0 -> qs );
26632669 const int8x16_t v1_0h = vld1q_s8 (y0 -> qs + 16 );
26642670 const int8x16_t v1_1l = vld1q_s8 (y1 -> qs );
26652671 const int8x16_t v1_1h = vld1q_s8 (y1 -> qs + 16 );
26662672
2667- // interleave
2668- const int8x16_t v1_0ls = vuzp1q_s8 (v1_0l , v1_0h );
2669- const int8x16_t v1_0hs = vuzp2q_s8 (v1_0l , v1_0h );
2670- const int8x16_t v1_1ls = vuzp1q_s8 (v1_1l , v1_1h );
2671- const int8x16_t v1_1hs = vuzp2q_s8 (v1_1l , v1_1h );
2672-
26732673#if defined(__ARM_FEATURE_DOTPROD )
26742674 // dot product into int32x4_t
2675- const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0ls , v1_0ls ), v0_0hs , v1_0hs );
2676- const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1ls , v1_1ls ), v0_1hs , v1_1hs );
2675+ const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0lz , v1_0l ), v0_0hz , v1_0h );
2676+ const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1lz , v1_1l ), v0_1hz , v1_1h );
26772677
26782678 sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (p_0 ), x0 -> d * y0 -> d );
26792679 sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (p_1 ), x1 -> d * y1 -> d );
26802680#else
2681- const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0ls ), vget_low_s8 (v1_0ls ));
2682- const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0ls ), vget_high_s8 (v1_0ls ));
2683- const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hs ), vget_low_s8 (v1_0hs ));
2684- const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hs ), vget_high_s8 (v1_0hs ));
2685-
2686- const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1ls ), vget_low_s8 (v1_1ls ));
2687- const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1ls ), vget_high_s8 (v1_1ls ));
2688- const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hs ), vget_low_s8 (v1_1hs ));
2689- const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hs ), vget_high_s8 (v1_1hs ));
2681+ const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0lz ), vget_low_s8 (v1_0l ));
2682+ const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0lz ), vget_high_s8 (v1_0l ));
2683+ const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hz ), vget_low_s8 (v1_0h ));
2684+ const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hz ), vget_high_s8 (v1_0h ));
2685+
2686+ const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1lz ), vget_low_s8 (v1_1l ));
2687+ const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1lz ), vget_high_s8 (v1_1l ));
2688+ const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hz ), vget_low_s8 (v1_1h ));
2689+ const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hz ), vget_high_s8 (v1_1h ));
26902690
26912691 const int32x4_t pl0 = vaddq_s32 (vpaddlq_s16 (pl0l ), vpaddlq_s16 (pl0h ));
26922692 const int32x4_t ph0 = vaddq_s32 (vpaddlq_s16 (ph0l ), vpaddlq_s16 (ph0h ));
0 commit comments