@@ -692,13 +692,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
692
692
693
693
for (int i = 0 ; i < nb ; i ++ ) {
694
694
float amax = 0.0f ; // absolute max
695
+ float max = 0.0f ;
695
696
696
697
for (int l = 0 ; l < QK4_0 ; l ++ ) {
697
698
const float v = x [i * QK4_0 + l ];
698
- amax = MAX (amax , fabsf (v ));
699
+ if (amax < fabsf (v )) {
700
+ amax = fabsf (v );
701
+ max = v ;
702
+ }
699
703
}
700
704
701
- const float d = amax / (( 1 << 3 ) - 1 ) ;
705
+ const float d = max / -8 ;
702
706
const float id = d ? 1.0f /d : 0.0f ;
703
707
704
708
y [i ].d = d ;
@@ -707,8 +711,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
707
711
const float v0 = x [i * QK4_0 + l + 0 ]* id ;
708
712
const float v1 = x [i * QK4_0 + l + 1 ]* id ;
709
713
710
- const uint8_t vi0 = ( int8_t )roundf (v0 ) + 8 ;
711
- const uint8_t vi1 = ( int8_t )roundf (v1 ) + 8 ;
714
+ const uint8_t vi0 = MIN ( 15 , ( int8_t )roundf (v0 ) + 8 ) ;
715
+ const uint8_t vi1 = MIN ( 15 , ( int8_t )roundf (v1 ) + 8 ) ;
712
716
713
717
assert (vi0 < 16 );
714
718
assert (vi1 < 16 );
@@ -728,28 +732,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
728
732
729
733
#if defined(__POWER9_VECTOR__ )
730
734
const vector float v85 = vec_splats (8.5f );
735
+ const vector signed int v15 = vec_splats (15 );
731
736
for (int i = 0 ; i < nb ; i ++ ) {
732
- float amax = 0.0f ; // absolute max
737
+ float max = 0.0f ;
738
+ float min = 0.0f ;
733
739
734
740
vector float srcv [8 ];
735
- vector float asrcv [8 ];
736
- vector float amaxv [8 ];
741
+ vector float maxv [8 ];
742
+ vector float minv [8 ];
737
743
738
744
for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = * (vector float * )(x + i * 32 + 4 * l );
739
- for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = vec_abs (srcv [l ]);
740
-
741
- for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = vec_max (asrcv [2 * l ], asrcv [2 * l + 1 ]);
742
- //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]);
743
- amaxv [0 ] = vec_max (amaxv [0 ], amaxv [2 ]);
744
- amaxv [4 ] = vec_max (amaxv [4 ], amaxv [6 ]);
745
- //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]);
746
- amaxv [0 ] = vec_max (amaxv [0 ], amaxv [4 ]);
747
-
748
- amax = MAX (
749
- MAX (vec_extract (amaxv [0 ], 0 ), vec_extract (amaxv [0 ], 1 )),
750
- MAX (vec_extract (amaxv [0 ], 2 ), vec_extract (amaxv [0 ], 3 )));
751
-
752
- const float d = amax / ((1 << 3 ) - 1 );
745
+ //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
746
+
747
+ for (int l = 0 ; l < 4 ; l ++ ) maxv [2 * l ] = vec_max (asrcv [2 * l ], asrcv [2 * l + 1 ]);
748
+ //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
749
+ maxv [0 ] = vec_max (maxv [0 ], maxv [2 ]);
750
+ maxv [4 ] = vec_max (maxv [4 ], maxv [6 ]);
751
+ //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
752
+ maxv [0 ] = vec_max (maxv [0 ], maxv [4 ]);
753
+
754
+ for (int l = 0 ; l < 4 ; l ++ ) minv [2 * l ] = vec_min (asrcv [2 * l ], asrcv [2 * l + 1 ]);
755
+ //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
756
+ minv [0 ] = vec_min (minv [0 ], minv [2 ]);
757
+ minv [4 ] = vec_min (minv [4 ], minv [6 ]);
758
+ //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
759
+ minv [0 ] = vec_min (minv [0 ], minv [4 ]);
760
+
761
+
762
+ max = MAX (
763
+ MAX (vec_extract (maxv [0 ], 0 ), vec_extract (maxv [0 ], 1 )),
764
+ MAX (vec_extract (maxv [0 ], 2 ), vec_extract (maxv [0 ], 3 )));
765
+ min = MIN (
766
+ MIN (vec_extract (minv [0 ], 0 ), vec_extract (minv [0 ], 1 )),
767
+ MIN (vec_extract (minv [0 ], 2 ), vec_extract (minv [0 ], 3 )));
768
+
769
+ const float magnitude = max >= fabsf (min ) ? max : min ;
770
+ const float d = magnitude / -8 ;
753
771
const float id = d ? 1.0 /d : 0.0 ;
754
772
755
773
y [i ].d = d ;
@@ -759,27 +777,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
759
777
for (int l = 0 ; l < 8 ; l ++ ) {
760
778
const vector float vf = vec_madd (srcv [l ], vid , v85 );
761
779
const vector signed int vi = vec_signed (vf );
780
+ const vector signed int vc = vec_min (vi , v15 );
762
781
763
- pb [2 * l + 0 ] = vec_extract (vi , 0 ) | (vec_extract (vi , 1 ) << 4 );
764
- pb [2 * l + 1 ] = vec_extract (vi , 2 ) | (vec_extract (vi , 3 ) << 4 );
782
+ pb [2 * l + 0 ] = vec_extract (vc , 0 ) | (vec_extract (vc , 1 ) << 4 );
783
+ pb [2 * l + 1 ] = vec_extract (vc , 2 ) | (vec_extract (vc , 3 ) << 4 );
765
784
}
766
785
}
767
786
#elif __ARM_NEON
768
787
for (int i = 0 ; i < nb ; i ++ ) {
769
788
float32x4_t srcv [8 ];
770
- float32x4_t asrcv [8 ];
771
- float32x4_t amaxv [8 ];
789
+ float32x4_t maxv [8 ];
790
+ float32x4_t minv [8 ];
772
791
773
792
for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = vld1q_f32 (x + i * 32 + 4 * l );
774
- for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = vabsq_f32 (srcv [l ]);
775
793
776
- for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = vmaxq_f32 (asrcv [2 * l ], asrcv [2 * l + 1 ]);
777
- for (int l = 0 ; l < 2 ; l ++ ) amaxv [4 * l ] = vmaxq_f32 (amaxv [4 * l ], amaxv [4 * l + 2 ]);
778
- for (int l = 0 ; l < 1 ; l ++ ) amaxv [8 * l ] = vmaxq_f32 (amaxv [8 * l ], amaxv [8 * l + 4 ]);
794
+ for (int l = 0 ; l < 4 ; l ++ ) maxv [2 * l ] = vmaxq_f32 (srcv [2 * l ], srcv [2 * l + 1 ]);
795
+ for (int l = 0 ; l < 2 ; l ++ ) maxv [4 * l ] = vmaxq_f32 (maxv [4 * l ], maxv [4 * l + 2 ]);
796
+ for (int l = 0 ; l < 1 ; l ++ ) maxv [8 * l ] = vmaxq_f32 (maxv [8 * l ], maxv [8 * l + 4 ]);
779
797
780
- const float amax = vmaxvq_f32 (amaxv [0 ]);
798
+ for (int l = 0 ; l < 4 ; l ++ ) minv [2 * l ] = vminq_f32 (srcv [2 * l ], srcv [2 * l + 1 ]);
799
+ for (int l = 0 ; l < 2 ; l ++ ) minv [4 * l ] = vminq_f32 (minv [4 * l ], minv [4 * l + 2 ]);
800
+ for (int l = 0 ; l < 1 ; l ++ ) minv [8 * l ] = vminq_f32 (minv [8 * l ], minv [8 * l + 4 ]);
781
801
782
- const float d = amax / ((1 << 3 ) - 1 );
802
+ const float max = vmaxvq_f32 (maxv [0 ]);
803
+ const float min = vminvq_f32 (minv [0 ]);
804
+
805
+ const float magnitude = max >= fabsf (min ) ? max : min ;
806
+ const float d = magnitude / -8 ;
783
807
const float id = d ? 1.0f /d : 0.0f ;
784
808
785
809
y [i ].d = d ;
@@ -788,9 +812,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
788
812
const float32x4_t v = vmulq_n_f32 (srcv [l ], id );
789
813
const float32x4_t vf = vaddq_f32 (v , vdupq_n_f32 (8.5f ));
790
814
const int32x4_t vi = vcvtq_s32_f32 (vf );
815
+ const int32x4_t vc = vminq_s32 (vi , vdupq_n_s32 (15 ));
791
816
792
- y [i ].qs [2 * l + 0 ] = vgetq_lane_s32 (vi , 0 ) | (vgetq_lane_s32 (vi , 1 ) << 4 );
793
- y [i ].qs [2 * l + 1 ] = vgetq_lane_s32 (vi , 2 ) | (vgetq_lane_s32 (vi , 3 ) << 4 );
817
+ y [i ].qs [2 * l + 0 ] = vgetq_lane_s32 (vc , 0 ) | (vgetq_lane_s32 (vc , 1 ) << 4 );
818
+ y [i ].qs [2 * l + 1 ] = vgetq_lane_s32 (vc , 2 ) | (vgetq_lane_s32 (vc , 3 ) << 4 );
794
819
}
795
820
}
796
821
#elif defined(__AVX2__ )
@@ -802,22 +827,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
802
827
__m256 v3 = _mm256_loadu_ps ( x + 24 );
803
828
x += 32 ;
804
829
805
- // Compute max(abs(e)) for the block
806
- const __m256 signBit = _mm256_set1_ps ( -0.0f );
807
- __m256 maxAbs = _mm256_andnot_ps ( signBit , v0 );
808
- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v1 ) );
809
- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v2 ) );
810
- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v3 ) );
830
+ // Compute max for the block
831
+ __m256 max = _mm256_max_ps ( v0 , v1 );
832
+ __m256 maxTmp = _mm256_max_ps ( v2 , v3 );
833
+ max = _mm256_max_ps ( max , maxTmp );
811
834
812
- __m128 max4 = _mm_max_ps ( _mm256_extractf128_ps ( maxAbs , 1 ), _mm256_castps256_ps128 ( maxAbs ) );
835
+ __m128 max4 = _mm_max_ps ( _mm256_extractf128_ps ( max , 1 ), _mm256_castps256_ps128 ( max ) );
813
836
max4 = _mm_max_ps ( max4 , _mm_movehl_ps ( max4 , max4 ) );
814
837
max4 = _mm_max_ss ( max4 , _mm_movehdup_ps ( max4 ) );
815
838
const float maxScalar = _mm_cvtss_f32 ( max4 );
816
839
840
+ // Compute min for the block
841
+ __m256 min = _mm256_min_ps ( v0 , v1 );
842
+ __m256 minTmp = _mm256_min_ps ( v2 , v3 );
843
+ min = _mm256_min_ps ( min , minTmp );
844
+
845
+ __m128 min4 = _mm_min_ps ( _mm256_extractf128_ps ( min , 1 ), _mm256_castps256_ps128 ( min ) );
846
+ min4 = _mm_min_ps ( min4 , _mm_movehl_ps ( min4 , min4 ) );
847
+ min4 = _mm_min_ss ( min4 , _mm_movehdup_ps ( min4 ) );
848
+ const float minScalar = _mm_cvtss_f32 ( min4 );
849
+
817
850
// Quantize these floats
818
- const float d = maxScalar / 7.0f ;
851
+ const float magnitude = maxScalar >= fabsf (minScalar ) ? maxScalar : minScalar ;
852
+ const float d = magnitude / -8.0f ;
819
853
y [i ].d = d ;
820
- const float id = ( maxScalar != 0.0f ) ? 7 .0f / maxScalar : 0.0f ;
854
+ const float id = ( magnitude != 0.0f ) ? -8 .0f / magnitude : 0.0f ;
821
855
const __m256 mul = _mm256_set1_ps ( id );
822
856
823
857
// Apply the multiplier
@@ -850,9 +884,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
850
884
const __m256i perm = _mm256_setr_epi32 ( 0 , 4 , 1 , 5 , 2 , 6 , 3 , 7 );
851
885
i0 = _mm256_permutevar8x32_epi32 ( i0 , perm );
852
886
853
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
887
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
854
888
const __m256i off = _mm256_set1_epi8 ( 8 );
855
889
i0 = _mm256_add_epi8 ( i0 , off );
890
+ const __m256i maxNibble = _mm256_set1_epi8 ( 15 );
891
+ i0 = _mm256_min_epi8 ( i0 , maxNibble );
856
892
857
893
// Compress the vector into 4 bit/value, and store
858
894
__m128i res = packNibbles ( i0 );
@@ -867,22 +903,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
867
903
__m256 v3 = _mm256_loadu_ps ( x + 24 );
868
904
x += 32 ;
869
905
870
- // Compute max(abs(e)) for the block
871
- const __m256 signBit = _mm256_set1_ps ( -0.0f );
872
- __m256 maxAbs = _mm256_andnot_ps ( signBit , v0 );
873
- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v1 ) );
874
- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v2 ) );
875
- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v3 ) );
906
+ // Compute max for the block
907
+ __m256 max = _mm256_max_ps ( v0 , v1 );
908
+ __m256 maxTmp = _mm256_max_ps ( v2 , v3 );
909
+ max = _mm256_max_ps ( max , maxTmp );
876
910
877
- __m128 max4 = _mm_max_ps ( _mm256_extractf128_ps ( maxAbs , 1 ), _mm256_castps256_ps128 ( maxAbs ) );
911
+ __m128 max4 = _mm_max_ps ( _mm256_extractf128_ps ( max , 1 ), _mm256_castps256_ps128 ( max ) );
878
912
max4 = _mm_max_ps ( max4 , _mm_movehl_ps ( max4 , max4 ) );
879
913
max4 = _mm_max_ss ( max4 , _mm_movehdup_ps ( max4 ) );
880
914
const float maxScalar = _mm_cvtss_f32 ( max4 );
881
915
916
+ // Compute min for the block
917
+ __m256 min = _mm256_min_ps ( v0 , v1 );
918
+ __m256 minTmp = _mm256_min_ps ( v2 , v3 );
919
+ min = _mm256_min_ps ( min , minTmp );
920
+
921
+ __m128 min4 = _mm_min_ps ( _mm256_extractf128_ps ( min , 1 ), _mm256_castps256_ps128 ( min ) );
922
+ min4 = _mm_min_ps ( min4 , _mm_movehl_ps ( min4 , min4 ) );
923
+ min4 = _mm_min_ss ( min4 , _mm_movehdup_ps ( min4 ) );
924
+ const float minScalar = _mm_cvtss_f32 ( min4 );
925
+
882
926
// Quantize these floats
883
- const float d = maxScalar / 7.0f ;
927
+ const float magnitude = maxScalar >= fabsf (minScalar ) ? maxScalar : minScalar ;
928
+ const float d = magnitude / -8.0f ;
884
929
y [i ].d = d ;
885
- const float id = ( maxScalar != 0.0f ) ? 7 .0f / maxScalar : 0.0f ;
930
+ const float id = ( magnitude != 0.0f ) ? -8 .0f / magnitude : 0.0f ;
886
931
const __m256 mul = _mm256_set1_ps ( id );
887
932
888
933
// Apply the multiplier
@@ -923,35 +968,46 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
923
968
ni0 = _mm_packs_epi16 ( ni0 , ni2 );
924
969
ni4 = _mm_packs_epi16 ( ni4 , ni6 );
925
970
926
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
927
- const __m128i off = _mm_set1_epi8 ( 8 );
971
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
972
+ const __m128i off = _mm_set1_epi8 ( 8 );
928
973
ni0 = _mm_add_epi8 ( ni0 , off );
929
974
ni4 = _mm_add_epi8 ( ni4 , off );
975
+ const __m128i maxNibble = _mm_set1_epi8 ( 15 );
976
+ ni0 = _mm_min_epi8 ( ni0 , maxNibble );
977
+ ni4 = _mm_min_epi8 ( ni4 , maxNibble );
930
978
931
979
// Compress the vector into 4 bit/value, and store
932
980
__m128i res = packNibbles ( ni0 , ni4 );
933
981
_mm_storeu_si128 ( ( __m128i * )y [i ].qs , res );
934
982
}
935
983
#elif defined(__wasm_simd128__ )
936
984
for (int i = 0 ; i < nb ; i ++ ) {
937
- float amax = 0.0f ; // absolute max
985
+ float max = 0.0f ;
986
+ float min = 0.0f ;
938
987
939
988
v128_t srcv [8 ];
940
- v128_t asrcv [8 ];
941
- v128_t amaxv [8 ];
989
+ v128_t maxv [8 ];
990
+ v128_t minv [8 ];
942
991
943
992
for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = wasm_v128_load (x + i * 32 + 4 * l );
944
- for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = wasm_f32x4_abs (srcv [l ]);
945
993
946
- for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = wasm_f32x4_max (asrcv [2 * l ], asrcv [2 * l + 1 ]);
947
- for (int l = 0 ; l < 2 ; l ++ ) amaxv [4 * l ] = wasm_f32x4_max (amaxv [4 * l ], amaxv [4 * l + 2 ]);
948
- for (int l = 0 ; l < 1 ; l ++ ) amaxv [8 * l ] = wasm_f32x4_max (amaxv [8 * l ], amaxv [8 * l + 4 ]);
994
+ for (int l = 0 ; l < 4 ; l ++ ) maxv [2 * l ] = wasm_f32x4_max (srcv [2 * l ], srcv [2 * l + 1 ]);
995
+ for (int l = 0 ; l < 2 ; l ++ ) maxv [4 * l ] = wasm_f32x4_max (maxv [4 * l ], maxv [4 * l + 2 ]);
996
+ for (int l = 0 ; l < 1 ; l ++ ) maxv [8 * l ] = wasm_f32x4_max (maxv [8 * l ], maxv [8 * l + 4 ]);
949
997
950
- amax = MAX (
951
- MAX ( wasm_f32x4_extract_lane ( amaxv [ 0 ], 0 ), wasm_f32x4_extract_lane ( amaxv [ 0 ], 1 )),
952
- MAX ( wasm_f32x4_extract_lane ( amaxv [ 0 ], 2 ), wasm_f32x4_extract_lane ( amaxv [ 0 ], 3 )) );
998
+ for ( int l = 0 ; l < 4 ; l ++ ) minv [ 2 * l ] = wasm_f32x4_min ( srcv [ 2 * l ], srcv [ 2 * l + 1 ]);
999
+ for ( int l = 0 ; l < 2 ; l ++ ) minv [ 4 * l ] = wasm_f32x4_min ( minv [ 4 * l ], minv [ 4 * l + 2 ]);
1000
+ for ( int l = 0 ; l < 1 ; l ++ ) minv [ 8 * l ] = wasm_f32x4_min ( minv [ 8 * l ], minv [ 8 * l + 4 ] );
953
1001
954
- const float d = amax / ((1 << 3 ) - 1 );
1002
+ max = MAX (
1003
+ MAX (wasm_f32x4_extract_lane (maxv [0 ], 0 ), wasm_f32x4_extract_lane (maxv [0 ], 1 )),
1004
+ MAX (wasm_f32x4_extract_lane (maxv [0 ], 2 ), wasm_f32x4_extract_lane (maxv [0 ], 3 )));
1005
+ min = MIN (
1006
+ MIN (wasm_f32x4_extract_lane (minv [0 ], 0 ), wasm_f32x4_extract_lane (minv [0 ], 1 )),
1007
+ MIN (wasm_f32x4_extract_lane (minv [0 ], 2 ), wasm_f32x4_extract_lane (minv [0 ], 3 )));
1008
+
1009
+ const float magnitude = max >= fabsf (min ) ? max : min ;
1010
+ const float d = magnitude / -8 ;
955
1011
const float id = d ? 1.0 /d : 0.0 ;
956
1012
957
1013
y [i ].d = d ;
@@ -960,9 +1016,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
960
1016
const v128_t v = wasm_f32x4_mul (srcv [l ], wasm_f32x4_splat (id ));
961
1017
const v128_t vf = wasm_f32x4_add (v , wasm_f32x4_splat (8.5f ));
962
1018
const v128_t vi = wasm_i32x4_trunc_sat_f32x4 (vf );
1019
+ const v128_t vc = wasm_i32x4_min_u (vi , wasm_i32x4_splat (15 ));
963
1020
964
- y [i ].qs [2 * l + 0 ] = wasm_i32x4_extract_lane (vi , 0 ) | (wasm_i32x4_extract_lane (vi , 1 ) << 4 );
965
- y [i ].qs [2 * l + 1 ] = wasm_i32x4_extract_lane (vi , 2 ) | (wasm_i32x4_extract_lane (vi , 3 ) << 4 );
1021
+ y [i ].qs [2 * l + 0 ] = wasm_i32x4_extract_lane (vc , 0 ) | (wasm_i32x4_extract_lane (vc , 1 ) << 4 );
1022
+ y [i ].qs [2 * l + 1 ] = wasm_i32x4_extract_lane (vc , 2 ) | (wasm_i32x4_extract_lane (vc , 3 ) << 4 );
966
1023
}
967
1024
}
968
1025
#else
@@ -1143,13 +1200,17 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
1143
1200
1144
1201
for (int i = 0 ; i < nb ; i ++ ) {
1145
1202
float amax = 0.0f ; // absolute max
1203
+ float max = 0.0f ;
1146
1204
1147
1205
for (int l = 0 ; l < QK4_2 ; l ++ ) {
1148
1206
const float v = x [i * QK4_2 + l ];
1149
- amax = MAX (amax , fabsf (v ));
1207
+ if (amax < fabsf (v )) {
1208
+ amax = fabsf (v );
1209
+ max = v ;
1210
+ }
1150
1211
}
1151
1212
1152
- const float d = amax / (( 1 << 3 ) - 1 ) ;
1213
+ const float d = max / -8 ;
1153
1214
1154
1215
const float id = d ? 1.0f /d : 0.0f ;
1155
1216
@@ -1159,8 +1220,8 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
1159
1220
const float v0 = x [i * QK4_2 + l + 0 ]* id ;
1160
1221
const float v1 = x [i * QK4_2 + l + 1 ]* id ;
1161
1222
1162
- const uint8_t vi0 = ( uint8_t )(v0 + 8.5f );
1163
- const uint8_t vi1 = ( uint8_t )(v1 + 8.5f );
1223
+ const uint8_t vi0 = MIN ( 15 , ( uint8_t )(v0 + 8.5f ) );
1224
+ const uint8_t vi1 = MIN ( 15 , ( uint8_t )(v1 + 8.5f ) );
1164
1225
1165
1226
assert (vi0 < 16 );
1166
1227
assert (vi1 < 16 );
@@ -1254,9 +1315,7 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int
1254
1315
1255
1316
block_q4_2 * restrict y = vy ;
1256
1317
1257
- //quantize_row_q4_2_reference(x, y, k);
1258
- // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
1259
- quantize_row_q4_2_rmse (x , y , k );
1318
+ quantize_row_q4_2_reference (x , y , k );
1260
1319
}
1261
1320
1262
1321
static void quantize_row_q4_3_reference (const float * restrict x , block_q4_3 * restrict y , int k ) {
@@ -1807,7 +1866,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1807
1866
[GGML_TYPE_Q4_2 ] = {
1808
1867
.dequantize_row_q = dequantize_row_q4_2 ,
1809
1868
.quantize_row_q = quantize_row_q4_2 ,
1810
- .quantize_row_q_reference = (quantize_row_q_t ) quantize_row_q4_2_rmse , // quantize_row_q4_2_reference,
1869
+ .quantize_row_q_reference = (quantize_row_q_t ) quantize_row_q4_2_reference ,
1811
1870
.quantize_row_q_dot = quantize_row_q8_0 ,
1812
1871
.vec_dot_q = ggml_vec_dot_q4_2_q8_0 ,
1813
1872
},
@@ -12144,8 +12203,7 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
12144
12203
for (int j = 0 ; j < n ; j += k ) {
12145
12204
block_q4_2 * restrict y = (block_q4_2 * )dst + j /QK4_2 ;
12146
12205
12147
- //quantize_row_q4_2_reference(src + j, y, k);
12148
- quantize_row_q4_2_rmse (src + j , y , k );
12206
+ quantize_row_q4_2_reference (src + j , y , k );
12149
12207
12150
12208
for (int i = 0 ; i < nb ; i ++ ) {
12151
12209
for (int l = 0 ; l < QK4_2 ; l += 2 ) {
0 commit comments