Commit 6eec060

Q4_2 quantization with rmse-optimized scale and quants
For quantize-stats we get q4_2: rmse 0.00159301, maxerr 0.17480469, 95pct<0.0030, median<0.0012. For 7B perplexity with BLAS enabled we get 6.2038 after 655 chunks. Quantization is slow (~90 seconds on my Mac for 7B) because it is not multi-threaded as in PR #896.
1 parent 8944a13 commit 6eec060
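
Background on the "rmse-optimized scale" in the title: for a fixed choice of integer quants L[i], the block scale d that minimizes sum_i (x[i] - d*L[i])^2 is d = sum(x[i]*L[i]) / sum(L[i]*L[i]), which is exactly the sumlx/suml2 value returned at the end of kQuantizeWithBounds in the diff below. The candidate loop then just tries a handful of trial scales (and both signs) and keeps the one with the largest (sum x*L)^2 / sum L^2, i.e. the smallest residual error after this rescale. A minimal standalone sketch of the closed-form step, with helper names (optimal_scale, block_rmse) made up for illustration:

// Illustrative sketch only (not part of the commit): the closed-form
// least-squares scale for a block, given already-chosen integer quants.
// For fixed L[i], the d minimizing sum_i (x[i] - d*L[i])^2 is
//     d = sum(x[i]*L[i]) / sum(L[i]*L[i]),
// which is the sumlx/suml2 value kQuantizeWithBounds returns.
#include <math.h>
#include <stdint.h>
#include <stdio.h>

static float optimal_scale(const float * x, const int8_t * L, int n) {
    float sumlx = 0.0f;
    int   suml2 = 0;
    for (int i = 0; i < n; ++i) {
        sumlx += x[i] * L[i];
        suml2 += L[i] * L[i];
    }
    return suml2 > 0 ? sumlx / suml2 : 0.0f;
}

static float block_rmse(const float * x, const int8_t * L, float d, int n) {
    float sum2 = 0.0f;
    for (int i = 0; i < n; ++i) {
        const float e = x[i] - d * L[i];
        sum2 += e * e;
    }
    return sqrtf(sum2 / n);
}

int main(void) {
    const float  x[4] = { 0.10f, -0.23f, 0.05f, 0.31f };
    const int8_t L[4] = { 3, -7, 1, 7 };  // pretend these quants were already picked
    const float  d    = optimal_scale(x, L, 4);
    printf("scale = %g, rmse = %g\n", d, block_rmse(x, L, d, 4));
    return 0;
}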

File tree: 1 file changed (+87, -3)

ggml.c: 87 additions & 3 deletions
@@ -19,6 +19,7 @@
 #include <inttypes.h>
 #include <stdio.h>
 #include <float.h>
+#include <limits.h>

 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
@@ -1123,12 +1124,94 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
     }
 }

+inline int nearestInt(float fval) {
+    assert(fval <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
+static float kQuantizeWithBounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
+        const float * restrict candidates, int8_t * restrict L) {
+    assert (nmin >= INT8_MIN);
+    assert (nmax <= INT8_MAX);
+    float amax = 0;
+    for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
+    if (!amax) { // all zero
+        for (int i=0; i<n; ++i) L[i] = 0;
+        return 1.f;
+    }
+    float best = 0, bestScale = 0;
+    for (int si=0; si<nCandidates; ++si) {
+        float iscale = candidates[si]/amax;
+        float sumlxP = 0; int suml2P = 0;
+        float sumlxM = 0; int suml2M = 0;
+        for (int i=0; i<n; ++i) {
+            int l = nearestInt(iscale*X[i]);
+            int lp = MAX(nmin, MIN(nmax, +l));
+            int lm = MAX(nmin, MIN(nmax, -l));
+            sumlxP += X[i]*lp; suml2P += lp*lp;
+            sumlxM += X[i]*lm; suml2M += lm*lm;
+        }
+        float sumlxP2 = sumlxP*sumlxP;
+        float sumlxM2 = sumlxM*sumlxM;
+        if (sumlxP2*suml2M > sumlxM2*suml2P) {
+            if (sumlxP2 > best*suml2P) {
+                best = sumlxP2/suml2P; bestScale = iscale;
+            }
+        } else {
+            if (sumlxM2 > best*suml2M) {
+                best = sumlxM2/suml2M; bestScale = -iscale;
+            }
+        }
+    }
+    float sumlx = 0; int suml2 = 0;
+    for (int i=0; i<n; ++i) {
+        int l = nearestInt(bestScale*X[i]);
+        l = MAX(nmin, MIN(nmax, l));
+        sumlx += X[i]*l; suml2 += l*l;
+        L[i] = l;
+    }
+    float scale = sumlx/suml2;
+    return scale;
+}
+
+static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
+#define kCandiateCount 8
+    static const float candidates[kCandiateCount] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
+    assert(k % QK4_2 == 0);
+
+    int8_t L[QK4_2];
+
+    const int nb = k / QK4_2;
+
+    for (int i = 0; i < nb; i++) {
+
+        float scale = kQuantizeWithBounds(QK4_2, -8, 7, x, kCandiateCount, candidates, L);
+        y[i].d = GGML_FP32_TO_FP16(scale);
+
+        for (int l = 0; l < QK4_2; l += 2) {
+            const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
+            const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
+
+        x += QK4_2;
+    }
+}
+
 static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK4_2 == 0);

     block_q4_2 * restrict y = vy;

-    quantize_row_q4_2_reference(x, y, k);
+    //quantize_row_q4_2_reference(x, y, k);
+    // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
+    quantize_row_q4_2_rmse(x, y, k);
 }

 // reference implementation for deterministic creation of model files
@@ -1558,7 +1641,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_2] = {
         .dequantize_row_q         = dequantize_row_q4_2,
         .quantize_row_q           = quantize_row_q4_2,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
         .quantize_row_q_dot       = quantize_row_q8_0,
         .vec_dot_q                = ggml_vec_dot_q4_2_q8_0,
     },
@@ -12298,7 +12381,8 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
     for (int j = 0; j < n; j += k) {
         block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;

-        quantize_row_q4_2_reference(src + j, y, k);
+        //quantize_row_q4_2_reference(src + j, y, k);
+        quantize_row_q4_2_rmse(src + j, y, k);

         for (int i = 0; i < nb; i++) {
             for (int l = 0; l < QK4_2; l += 2) {
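
A note on the nearestInt helper in the diff: adding 12582912.f (i.e. 1.5 * 2^23) to a float of magnitude below 2^22 forces the hardware addition to round the value to an integer and leaves that integer, biased by 0x00400000, in the low 23 mantissa bits; masking and subtracting the bias recovers round-to-nearest (ties to even under the default rounding mode). A small standalone check against lrintf, assuming IEEE-754 floats and the default rounding mode (the test harness is illustrative, not part of the commit):

// Standalone check of the "magic number" rounding trick used by nearestInt.
// Assumes IEEE-754 floats and the default round-to-nearest rounding mode.
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <string.h>

static inline int nearest_int(float fval) {        // same body as nearestInt in the diff
    assert(fval <= 4194303.f);                      // trick requires |fval| < 2^22
    float val = fval + 12582912.f;                  // 1.5 * 2^23: forces rounding to an integer
    int i; memcpy(&i, &val, sizeof(int));
    return (i & 0x007fffff) - 0x00400000;           // keep the mantissa bits, remove the bias
}

int main(void) {
    for (float x = -1000.0f; x <= 1000.0f; x += 0.37f) {
        if (nearest_int(x) != (int) lrintf(x)) {
            printf("mismatch at %f: %d vs %ld\n", x, nearest_int(x), lrintf(x));
            return 1;
        }
    }
    printf("nearest_int matches lrintf on the sampled range\n");
    return 0;
}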
