@@ -599,10 +599,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
599
599
for (int l = 0 ; l < 2 ; l ++ ) amaxv [4 * l ] = vmaxq_f32 (amaxv [4 * l ], amaxv [4 * l + 2 ]);
600
600
for (int l = 0 ; l < 1 ; l ++ ) amaxv [8 * l ] = vmaxq_f32 (amaxv [8 * l ], amaxv [8 * l + 4 ]);
601
601
602
- // absolute max
603
- const float amax = MAX (
604
- MAX (vgetq_lane_f32 (amaxv [0 ], 0 ), vgetq_lane_f32 (amaxv [0 ], 1 )),
605
- MAX (vgetq_lane_f32 (amaxv [0 ], 2 ), vgetq_lane_f32 (amaxv [0 ], 3 )));
602
+ const float amax = vmaxvq_f32 (amaxv [0 ]);
606
603
607
604
const float d = amax / ((1 << 3 ) - 1 );
608
605
const float id = d ? 1.0f /d : 0.0f ;
@@ -924,7 +921,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
924
921
float32x4_t minv [8 ];
925
922
float32x4_t maxv [8 ];
926
923
927
- for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = vld1q_f32 (x + i * 32 + 4 * l );
924
+ for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = vld1q_f32 (x + i * QK + 4 * l );
928
925
929
926
for (int l = 0 ; l < 4 ; l ++ ) minv [2 * l ] = vminq_f32 (srcv [2 * l ], srcv [2 * l + 1 ]);
930
927
for (int l = 0 ; l < 2 ; l ++ ) minv [4 * l ] = vminq_f32 (minv [4 * l ], minv [4 * l + 2 ]);
@@ -947,7 +944,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
947
944
948
945
for (int l = 0 ; l < 8 ; l ++ ) {
949
946
const float32x4_t v = vmulq_n_f32 (vsubq_f32 (srcv [l ], minv0 ), id );
950
- const int32x4_t vi = vcvtq_s32_f32 (v );
947
+ const float32x4_t vf = vaddq_f32 (v , vdupq_n_f32 (0.5f )); // needed to round to nearest
948
+ const int32x4_t vi = vcvtq_s32_f32 (vf );
951
949
952
950
y [i ].qs [2 * l + 0 ] = vgetq_lane_s32 (vi , 0 ) | (vgetq_lane_s32 (vi , 1 ) << 4 );
953
951
y [i ].qs [2 * l + 1 ] = vgetq_lane_s32 (vi , 2 ) | (vgetq_lane_s32 (vi , 3 ) << 4 );
0 commit comments