Skip to content

Commit b4dfdf7

Browse files
committed
Deduplicate q4 quantization functions
1 parent ae44e23 commit b4dfdf7

File tree

1 file changed

+63
-104
lines changed

1 file changed

+63
-104
lines changed

ggml.c

Lines changed: 63 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -400,16 +400,63 @@ static inline __m128i packNibbles( __m256i bytes )
400400
// method 5
401401
// blocks of QK elements
402402
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
403+
404+
// reference implementation for deterministic creation of model files
405+
static void quantize_row_q4_0_reference(const float * restrict x, void * restrict y, int k) {
406+
assert(k % QK == 0);
407+
const int nb = k / QK;
408+
409+
const size_t bs = sizeof(float) + QK/2;
410+
411+
uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
412+
uint8_t * restrict pb = ((uint8_t *)y + 0*bs + sizeof(float));
413+
414+
uint8_t pp[QK/2];
415+
416+
for (int i = 0; i < nb; i++) {
417+
float amax = 0.0f; // absolute max
418+
419+
for (int l = 0; l < QK; l++) {
420+
const float v = x[i*QK + l];
421+
amax = MAX(amax, fabsf(v));
422+
}
423+
424+
const float d = amax / ((1 << 3) - 1);
425+
const float id = d ? 1.0f/d : 0.0f;
426+
427+
*(float *)pd = d;
428+
pd += bs;
429+
430+
for (int l = 0; l < QK; l += 2) {
431+
const float v0 = x[i*QK + l + 0]*id;
432+
const float v1 = x[i*QK + l + 1]*id;
433+
434+
const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
435+
const uint8_t vi1 = ((int8_t) (round(v1))) + 8;
436+
437+
assert(vi0 >= 0 && vi0 < 16);
438+
assert(vi1 >= 0 && vi1 < 16);
439+
440+
pp[l/2] = vi0 | (vi1 << 4);
441+
}
442+
443+
memcpy(pb, pp, sizeof(pp));
444+
pb += bs;
445+
}
446+
}
447+
403448
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
404449
assert(k % QK == 0);
405450

451+
#if __ARM_NEON || defined(__AVX2__) || defined(__wasm_simd128__)
406452
const int nb = k / QK;
407453
const size_t bs = sizeof(float) + QK/2;
408454

409455
uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
410456
uint8_t * restrict pb = ((uint8_t *)y + 0*bs + sizeof(float));
411457

412458
uint8_t pp[QK/2];
459+
#endif
413460

414461
#if __ARM_NEON
415462
#if QK == 32
@@ -566,36 +613,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
566613
#endif
567614
#else
568615
// scalar
569-
for (int i = 0; i < nb; i++) {
570-
float amax = 0.0f; // absolute max
571-
572-
for (int l = 0; l < QK; l++) {
573-
const float v = x[i*QK + l];
574-
amax = MAX(amax, fabsf(v));
575-
}
576-
577-
const float d = amax / ((1 << 3) - 1);
578-
const float id = d ? 1.0f/d : 0.0f;
579-
580-
*(float *)pd = d;
581-
pd += bs;
582-
583-
for (int l = 0; l < QK; l += 2) {
584-
const float v0 = x[i*QK + l + 0]*id;
585-
const float v1 = x[i*QK + l + 1]*id;
586-
587-
const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
588-
const uint8_t vi1 = ((int8_t) (round(v1))) + 8;
589-
590-
assert(vi0 >= 0 && vi0 < 16);
591-
assert(vi1 >= 0 && vi1 < 16);
592-
593-
pp[l/2] = vi0 | (vi1 << 4);
594-
}
595-
596-
memcpy(pb, pp, sizeof(pp));
597-
pb += bs;
598-
}
616+
quantize_row_q4_0_reference(x, y, k);
599617
#endif
600618
}
601619

@@ -10709,49 +10727,23 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t
1070910727

1071010728
assert(k % qk == 0);
1071110729

10712-
const size_t pp_size = qk / 2;
10713-
uint8_t * pp = (uint8_t *) alloca(pp_size);
10714-
1071510730
char * pdst = (char *) dst;
1071610731

1071710732
for (int j = 0; j < n; j += k) {
1071810733
uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
1071910734
uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
1072010735

10721-
for (int i = 0; i < nb; i++) {
10722-
float amax = 0.0f; // absolute max
10723-
10724-
{
10725-
for (int l = 0; l < qk; l++) {
10726-
const float v = src[j + i*qk + l];
10727-
amax = MAX(amax, fabsf(v));
10728-
}
10729-
10730-
const float d = amax / ((1 << 3) - 1);
10731-
const float id = d ? 1.0f/d : 0.0f;
10732-
10733-
*(float *) pd = d;
10734-
pd += bs;
10736+
quantize_row_q4_0_reference(src + j, pd, k);
1073510737

10736-
for (int l = 0; l < qk; l += 2) {
10737-
const float v0 = (src[j + i*qk + l + 0])*id;
10738-
const float v1 = (src[j + i*qk + l + 1])*id;
10739-
10740-
const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
10741-
const uint8_t vi1 = ((int8_t) (round(v1))) + 8;
10742-
10743-
assert(vi0 >= 0 && vi0 < 16);
10744-
assert(vi1 >= 0 && vi1 < 16);
10745-
10746-
hist[vi0]++;
10747-
hist[vi1]++;
10748-
10749-
pp[l/2] = vi0 | (vi1 << 4);
10750-
}
10738+
for (int i = 0; i < nb; i++) {
10739+
for (int l = 0; l < qk; l += 2) {
10740+
const uint8_t vi0 = pb[l/2] & 0xF;
10741+
const uint8_t vi1 = pb[l/2] >> 4;
1075110742

10752-
memcpy(pb, pp, pp_size);
10753-
pb += bs;
10743+
hist[vi0]++;
10744+
hist[vi1]++;
1075410745
}
10746+
pb += bs;
1075510747
}
1075610748
}
1075710749

@@ -10765,56 +10757,23 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
1076510757

1076610758
assert(k % qk == 0);
1076710759

10768-
const size_t pp_size = qk / 2;
10769-
uint8_t * pp = (uint8_t *) alloca(pp_size);
10770-
1077110760
char * pdst = (char *) dst;
1077210761

1077310762
for (int j = 0; j < n; j += k) {
1077410763
uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
10775-
uint8_t * pm = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
1077610764
uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));
1077710765

10778-
//printf("n = %d, k = %d, nb = %d, row_size = %d, j = %d, pm = %p, pd = %p, pb = %p\n", n, k, nb, row_size, j, pm, pd, pb);
10766+
quantize_row_q4_1(src + j, pd, k);
1077910767

1078010768
for (int i = 0; i < nb; i++) {
10781-
float min = FLT_MAX;
10782-
float max = -FLT_MAX;
10783-
10784-
{
10785-
for (int l = 0; l < qk; l++) {
10786-
const float v = src[j + i*qk + l];
10787-
if (v < min) min = v;
10788-
if (v > max) max = v;
10789-
}
10790-
10791-
const float d = (max - min) / ((1 << 4) - 1);
10792-
const float id = d ? 1.0f/d : 0.0f;
10793-
10794-
*(float *) pd = d;
10795-
*(float *) pm = min;
10796-
pd += bs;
10797-
pm += bs;
10798-
10799-
for (int l = 0; l < qk; l += 2) {
10800-
const float v0 = (src[j + i*qk + l + 0] - min)*id;
10801-
const float v1 = (src[j + i*qk + l + 1] - min)*id;
10802-
10803-
const uint8_t vi0 = round(v0);
10804-
const uint8_t vi1 = round(v1);
10805-
10806-
assert(vi0 >= 0 && vi0 < 16);
10807-
assert(vi1 >= 0 && vi1 < 16);
10808-
10809-
hist[vi0]++;
10810-
hist[vi1]++;
10811-
10812-
pp[l/2] = vi0 | (vi1 << 4);
10813-
}
10769+
for (int l = 0; l < qk; l += 2) {
10770+
const uint8_t vi0 = pb[l/2] & 0xF;
10771+
const uint8_t vi1 = pb[l/2] >> 4;
1081410772

10815-
memcpy(pb, pp, pp_size);
10816-
pb += bs;
10773+
hist[vi0]++;
10774+
hist[vi1]++;
1081710775
}
10776+
pb += bs;
1081810777
}
1081910778
}
1082010779

0 commit comments

Comments
 (0)