@@ -687,11 +687,11 @@ static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict
687687 }
688688 return 0.0f ;
689689 }
690+
690691 bool negative_scale = false;
691692 if (signed_scale && - nmin != nmax ) {
692693 // the max side should have the biggest range
693- // FIXME: this is incorrect when the weights[.] do not sort in the same order as fabsf(x[.])
694- // or is it some other condition?
694+ // FIXME: this is not always the best sign
695695 if ((x [amax_i ] < 0.0f ) == (- nmin < nmax )) {
696696 // [-4, 3] ==> [-3, 4]
697697 const int tmp = nmin ;
@@ -762,7 +762,7 @@ static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict
762762 .i = i ,
763763 };
764764 } else {
765- // stop when the inverse scale would result in clamping the max (FIXME: most important) value
765+ // stop when the inverse scale would result in clamping the most important value
766766 break ;
767767 }
768768 }
@@ -802,6 +802,182 @@ static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict
802802 return negative_scale ? - scale : scale ;
803803}
804804
805+ // Very similar to make_qkxs_quants, but the sign of the scale is not assumed to be the sign of the absmax value.
806+ static float make_qkxss_quants (int n , int nmin , int nmax , const float * restrict x , const float * restrict weights , int8_t * restrict L , int8_t * restrict Laux , struct fraction * restrict Faux ) {
807+ // start at zero
808+ nmin = MIN (0 , nmin );
809+ nmax = MAX (0 , nmax );
810+ float amax = 0.0f ;
811+ float min = 0.0f ;
812+ float max = 0.0f ;
813+ float w_amax = 0.0f ;
814+ int amax_i = -1 ;
815+ int w_amax_i = -1 ;
816+ for (int i = 0 ; i < n ; ++ i ) {
817+ const float w = weights ? weights [i ] : x [i ] * x [i ];
818+ const float ax = fabsf (x [i ]);
819+ const float wax = w * ax ;
820+ if (ax > amax ) { amax = ax ; amax_i = i ; }
821+ if (x [i ] > max ) { max = x [i ]; }
822+ if (x [i ] < min ) { min = x [i ]; }
823+ // Find the most important value
824+ if (wax > w_amax ) { w_amax = wax ; w_amax_i = i ; }
825+ }
826+
827+ if (amax < GROUP_MAX_EPS || amax_i < 0 || w_amax_i < 0 ) { // all zero
828+ for (int i = 0 ; i < n ; ++ i ) { L [i ] = 0 ; }
829+ return 0.0f ;
830+ }
831+
832+ // Use the side which will clamp first.
833+ // The first clamped value is the absmax at the end of the common range.
834+ // TODO: reduce the search space when one of the ranges is 0
835+ const int amax_range = MIN (- nmin , nmax );
836+ float sumlx_p = 0.0f ;
837+ float suml2_p = 0.0f ;
838+ float sumlx_n = 0.0f ;
839+ float suml2_n = 0.0f ;
840+ float scale = 0.0f ;
841+ float best = 0.0f ;
842+ float best_denom = 1.0f ;
843+ int best_i = -2 ; // not consecutive with 0..n_frac
844+ // Pre-calculate the half-point for the common range.
845+ // All smaller vectors have a representable vector with twice the values, and thus can be skipped.
846+ if (amax_range > 1 ) {
847+ const float iscale = ((float )(amax_range / 2 + 1 ))/amax ;
848+ for (int i = 0 ; i < n ; ++ i ) {
849+ const float w = weights ? weights [i ] : x [i ] * x [i ];
850+ int l = MAX (nmin , MIN (lroundf (x [i ] * iscale ), nmax ));
851+ Laux [i ] = l ;
852+ suml2_p += w * l * l ;
853+ sumlx_p += w * l * x [i ];
854+ }
855+ sumlx_n = - sumlx_p ;
856+ suml2_n = suml2_p ;
857+ const float current_p = sumlx_p * sumlx_p ;
858+ if (suml2_p > 0.0f && current_p * best_denom > best * suml2_p ) {
859+ best = current_p ;
860+ best_denom = suml2_p ;
861+ scale = sumlx_p / suml2_p ;
862+ for (int i = 0 ; i < n ; ++ i ) {
863+ L [i ] = Laux [i ];
864+ }
865+ best_i = -1 ; // right before 0 of the loop after sorting
866+ }
867+ } else {
868+ for (int i = 0 ; i < n ; ++ i ) {
869+ Laux [i ] = 0 ;
870+ }
871+ }
872+
873+ const int imax_range = MAX (nmax , - nmin );
874+ const int max_odd = 2 * (imax_range + 1 ) + 1 ;
875+ const float wmax = fabsf (x [w_amax_i ]);
876+ int n_frac = 0 ;
877+ for (int i = 0 ; i < n ; ++ i ) {
878+ // assuming nmin <= nmax
879+ const int odd_max = MAX (nmax , - nmin );
880+ const float v = fabsf (x [i ]);
881+ const float v_max_odd = v * max_odd ;
882+ for (int j = abs (Laux [i ]); j < odd_max ; ++ j ) {
883+ const float odd = 2 * j + 1 ;
884+ const float wmax_odd = wmax * odd ;
885+ if (wmax_odd < v_max_odd ) {
886+ Faux [n_frac ++ ] = (struct fraction ){
887+ .numer = v ,
888+ .denom = odd ,
889+ .i = i ,
890+ };
891+ } else {
892+ // stop when the inverse scale would result in clamping the most important value
893+ break ;
894+ }
895+ }
896+ }
897+
898+ qsort (Faux , n_frac , sizeof (struct fraction ), compare_fractions_desc );
899+
900+ const float max_common_odd = (MIN (nmax , - nmin ) * 2 ) + 1 ;
901+ const float max_odd_p = (nmax * 2 ) + 1 ;
902+ const float max_odd_n = (- nmin * 2 ) + 1 ;
903+
904+ for (int i = 0 ; i < n_frac ; ++ i ) {
905+ // maximize the weighted cosine similarity
906+ const int ii = Faux [i ].i ;
907+ const float w = weights ? weights [ii ] : x [ii ] * x [ii ];
908+ const float lx = w * Faux [i ].numer ;
909+ const float odd = Faux [i ].denom ;
910+ const float l2 = w * odd ;
911+
912+ Laux [ii ] += x [ii ] < 0.0f ? -1 : 1 ;
913+
914+ float sumlx = 0.0f ;
915+ float proj = 0.0f ;
916+ float norm = 0.0f ;
917+ if (odd < max_common_odd ) {
918+ sumlx_p += lx ;
919+ suml2_p += l2 ;
920+ sumlx_n -= lx ;
921+ suml2_n += l2 ;
922+
923+ sumlx = sumlx_p ;
924+ proj = sumlx_p * sumlx_p ;
925+ norm = suml2_p ;
926+
927+ // avoid double-copying Laux in a single iteration
928+ if (suml2_p != suml2_n && suml2_p * suml2_n > 0.0f ) {
929+ const float proj_n = sumlx_n * sumlx_n ;
930+ if (proj_n * norm > proj * suml2_n ) {
931+ sumlx = sumlx_n ;
932+ proj = proj_n ;
933+ norm = suml2_n ;
934+ }
935+ }
936+ } else if (x [ii ] < 0.0f ? odd < max_odd_n : odd < max_odd_p ) {
937+ sumlx_p += lx ;
938+ suml2_p += l2 ;
939+
940+ sumlx = sumlx_p ;
941+ proj = sumlx_p * sumlx_p ;
942+ norm = suml2_p ;
943+ } else {
944+ // outside the positive range means we're now into negatives
945+ sumlx_n -= lx ;
946+ suml2_n += l2 ;
947+
948+ sumlx = sumlx_n ;
949+ proj = sumlx_n * sumlx_n ;
950+ norm = suml2_n ;
951+ }
952+ if (norm > 0.0f && proj * best_denom > best * norm ) {
953+ best = proj ;
954+ best_denom = norm ;
955+ scale = sumlx / norm ;
956+ if (i == best_i + 1 ) {
957+ // reduce copies for consecutive bests
958+ L [ii ] += x [ii ] < 0.0f ? -1 : 1 ;
959+ } else {
960+ for (int j = 0 ; j < n ; ++ j ) {
961+ L [j ] = Laux [j ];
962+ }
963+ }
964+ best_i = i ;
965+ }
966+ }
967+
968+ if (scale < 0.0f ) {
969+ for (int i = 0 ; i < n ; ++ i ) {
970+ L [i ] = MAX (nmin , MIN (- L [i ], nmax )) - nmin ;
971+ }
972+ } else {
973+ for (int i = 0 ; i < n ; ++ i ) {
974+ L [i ] = MAX (nmin , MIN (L [i ], nmax )) - nmin ;
975+ }
976+ }
977+
978+ return scale ;
979+ }
980+
805981// non-linear exhaustive search with cumulative sums
806982// Need Faux to have room for n*k fractions
807983static float make_qkxs_nl_quants (int n , int k , const float * restrict x , const float * restrict weights , const int8_t * restrict kvalues , uint8_t * restrict L , uint8_t * restrict Laux , struct fraction * restrict Faux , bool signed_scale ) {
@@ -874,6 +1050,7 @@ static float make_qkxs_nl_quants(int n, int k, const float * restrict x, const f
8741050 }
8751051
8761052 // Non-linear mappings are usually not symmetric, so try negating the scale
1053+ // This is the same as above, but keeping the old best if the new best is not better.
8771054 if (signed_scale ) {
8781055 for (int i = 0 ; i < n ; ++ i ) {
8791056 Laux [i ] = koff ;
@@ -1298,7 +1475,6 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
12981475 float amax = 0 ;
12991476 for (int j = 0 ; j < QK_K /16 ; ++ j ) {
13001477 scales [j ] = make_qkxs_quants (16 , -4 , 3 , x + 16 * j , weights , L + 16 * j , Laux , Faux , true);
1301- // scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
13021478 float scale = fabsf (scales [j ]);
13031479 if (scale > amax ) {
13041480 amax = scale ; max_scale = scales [j ];
@@ -1324,21 +1500,6 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
13241500 y [i ].d = GGML_FP32_TO_FP16 (0.f );
13251501 }
13261502
1327- // int8_t sc;
1328- // for (int j = 0; j < QK_K/16; ++j) {
1329- // sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
1330- // sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
1331- // float d = GGML_FP16_TO_FP32(y[i].d) * sc;
1332- // if (!d) {
1333- // continue;
1334- // }
1335- // for (int ii = 0; ii < 16; ++ii) {
1336- // int l = nearest_int(x[16*j + ii]/d);
1337- // l = MAX(-4, MIN(3, l));
1338- // L[16*j + ii] = l + 4;
1339- // }
1340- // }
1341-
13421503 memset (y [i ].hmask , 0 , QK_K /8 );
13431504 // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
13441505 int m = 0 ;
@@ -1441,14 +1602,12 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
14411602 for (int l = 0 ; l < 16 ; ++ l ) sumw += weight [l ];
14421603 sw [j ] = sumw ;
14431604
1444- // scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);
14451605 scales [j ] = make_qkxs_quants (16 , -4 , 3 , x + 16 * j , weight , L + 16 * j , Laux , Faux , true);
14461606
14471607 }
14481608
14491609 memset (y [i ].scales , 0 , 12 );
14501610
1451- // float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);
14521611 float d_block = make_qkxs_quants (QK_K /16 , -32 , 31 , scales , sw , Ls , Laux , Faux , true);
14531612 for (int j = 0 ; j < QK_K /16 ; ++ j ) {
14541613 int l = Ls [j ];
@@ -1462,21 +1621,6 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
14621621 }
14631622 y [i ].d = GGML_FP32_TO_FP16 (d_block );
14641623
1465- // int8_t sc;
1466- // for (int j = 0; j < QK_K/16; ++j) {
1467- // sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
1468- // sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
1469- // float d = GGML_FP16_TO_FP32(y[i].d) * sc;
1470- // if (!d) {
1471- // continue;
1472- // }
1473- // for (int ii = 0; ii < 16; ++ii) {
1474- // int l = nearest_int(x[16*j + ii]/d);
1475- // l = MAX(-4, MIN(3, l));
1476- // L[16*j + ii] = l + 4;
1477- // }
1478- // }
1479-
14801624 memset (y [i ].hmask , 0 , QK_K /8 );
14811625 // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
14821626 int m = 0 ;
@@ -2526,7 +2670,7 @@ static void quantize_row_tq2_0_impl(const float * restrict x, block_tq2_0 * rest
25262670 const float * xb = x + QK_K * ib ;
25272671 const float * qw = quant_weights + QK_K * ib ;
25282672 for (int j = 0 ; j < QK_K ; ++ j ) { weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]); }
2529- float d = make_qkxs_quants (QK_K , -1 , 2 , xb , weight , L , Laux , Faux , true );
2673+ float d = make_qkxss_quants (QK_K , -1 , 2 , xb , weight , L , Laux , Faux );
25302674 y [ib ].d = GGML_FP32_TO_FP16 (d );
25312675
25322676 for (size_t j = 0 ; j < sizeof (y -> qs ); j += 32 ) {
0 commit comments