Skip to content

Commit 58175e3

Browse files
committed
Rearrange Q4_1 quantization to work for multipart models. (Fix #152)
1 parent 5e66b6b commit 58175e3

File tree

2 files changed

+49
-38
lines changed

2 files changed

+49
-38
lines changed

ggml.c

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -607,10 +607,11 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
607607
assert(k % QK == 0);
608608

609609
const int nb = k / QK;
610+
const size_t bs = 2*sizeof(float) + QK/2;
610611

611-
float * restrict pm = (float *) (y);
612-
float * restrict pd = (float *) (pm + nb);
613-
uint8_t * restrict pb = (uint8_t *) (pd + nb);
612+
uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
613+
uint8_t * restrict pm = ((uint8_t *)y + 0*bs + sizeof(float));
614+
uint8_t * restrict pb = ((uint8_t *)y + 0*bs + 2*sizeof(float));
614615

615616
uint8_t pp[QK/2];
616617

@@ -627,8 +628,10 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
627628
const float d = (max - min) / ((1 << 4) - 1);
628629
const float id = d ? 1.0f/d : 0.0f;
629630

630-
pm[i] = min;
631-
pd[i] = d;
631+
*(float *)pm = min;
632+
*(float *)pd = d;
633+
pm += bs;
634+
pd += bs;
632635

633636
for (int l = 0; l < QK; l += 2) {
634637
const float v0 = (x[i*QK + l + 0] - min)*id;
@@ -643,7 +646,8 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
643646
pp[l/2] = vi0 | (vi1 << 4);
644647
}
645648

646-
memcpy(pb + i*QK/2, pp, sizeof(pp));
649+
memcpy(pb, pp, sizeof(pp));
650+
pb += bs;
647651
}
648652
}
649653

@@ -687,16 +691,17 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
687691
assert(k % QK == 0);
688692

689693
const int nb = k / QK;
694+
const size_t bs = 2*sizeof(float) + QK/2;
690695

691-
const float * restrict pm = (const float *) (x);
692-
const float * restrict pd = (const float *) (pm + nb);
693-
const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
696+
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
697+
const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
698+
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
694699

695700
for (int i = 0; i < nb; i++) {
696-
const float m = pm[i];
697-
const float d = pd[i];
701+
const float d = *(const float *) (pd + i*bs);
702+
const float m = *(const float *) (pm + i*bs);
698703

699-
const uint8_t * restrict pp = pb + i*QK/2;
704+
const uint8_t * restrict pp = pb + i*bs;
700705

701706
for (int l = 0; l < QK; l += 2) {
702707
const uint8_t vi = pp[l/2];
@@ -1553,14 +1558,16 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
15531558
inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
15541559
const int nb = n / QK;
15551560

1556-
const float * restrict pm0 = (const float *) x;
1557-
const float * restrict pm1 = (const float *) y;
1561+
const size_t bs = 2*sizeof(float) + QK/2;
1562+
1563+
const uint8_t * restrict pd0 = ((const uint8_t *)x + 0*bs);
1564+
const uint8_t * restrict pd1 = ((const uint8_t *)y + 0*bs);
15581565

1559-
const float * restrict pd0 = (const float *) (pm0 + nb);
1560-
const float * restrict pd1 = (const float *) (pm1 + nb);
1566+
const uint8_t * restrict pm0 = ((const uint8_t *)x + 0*bs + sizeof(float));
1567+
const uint8_t * restrict pm1 = ((const uint8_t *)y + 0*bs + sizeof(float));
15611568

1562-
const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
1563-
const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
1569+
const uint8_t * restrict pb0 = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
1570+
const uint8_t * restrict pb1 = ((const uint8_t *)y + 0*bs + 2*sizeof(float));
15641571

15651572
float sumf = 0.0;
15661573

@@ -1573,14 +1580,14 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
15731580

15741581
// Main loop
15751582
for (int i = 0; i < nb; ++i) {
1576-
const float * m0 = (const float *) (pm0 + i);
1577-
const float * m1 = (const float *) (pm1 + i);
1583+
const float * m0 = (const float *) (pm0 + i*bs);
1584+
const float * m1 = (const float *) (pm1 + i*bs);
15781585

1579-
const float * d0 = (const float *) (pd0 + i);
1580-
const float * d1 = (const float *) (pd1 + i);
1586+
const float * d0 = (const float *) (pd0 + i*bs);
1587+
const float * d1 = (const float *) (pd1 + i*bs);
15811588

1582-
const uint8_t * restrict p0 = pb0 + i*QK/2;
1583-
const uint8_t * restrict p1 = pb1 + i*QK/2;
1589+
const uint8_t * restrict p0 = pb0 + i*bs;
1590+
const uint8_t * restrict p1 = pb1 + i*bs;
15841591

15851592
const __m256 d0v = _mm256_broadcast_ss( d0 );
15861593
const __m256 d1v = _mm256_broadcast_ss( d1 );
@@ -1646,14 +1653,14 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
16461653
#else
16471654
// scalar
16481655
for (int i = 0; i < nb; i++) {
1649-
const float m0 = pm0[i];
1650-
const float m1 = pm1[i];
1656+
const float * m0 = (const float *) (pm0 + i*bs);
1657+
const float * m1 = (const float *) (pm1 + i*bs);
16511658

1652-
const float d0 = pd0[i];
1653-
const float d1 = pd1[i];
1659+
const float * d0 = (const float *) (pd0 + i*bs);
1660+
const float * d1 = (const float *) (pd1 + i*bs);
16541661

1655-
const uint8_t * restrict p0 = pb0 + i*QK/2;
1656-
const uint8_t * restrict p1 = pb1 + i*QK/2;
1662+
const uint8_t * restrict p0 = pb0 + i*bs;
1663+
const uint8_t * restrict p1 = pb1 + i*bs;
16571664

16581665
for (int j = 0; j < QK/2; j++) {
16591666
const uint8_t v0 = p0[j];

utils.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t
486486

487487
size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t * hist) {
488488
const int nb = k / qk;
489-
const size_t row_size = nb*(2*sizeof(float) + sizeof(uint8_t)*qk/2);
489+
const size_t bs = (2*sizeof(float) + sizeof(uint8_t)*qk/2);
490+
const size_t row_size = nb*bs;
490491

491492
assert(k % qk == 0);
492493

@@ -495,10 +496,10 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
495496

496497
char * pdst = (char *) dst;
497498

498-
for (int j = 0; j < n; j += k) {
499-
float * pm = (float *) (pdst + (j/k)*row_size);
500-
float * pd = (float *) (pm + nb);
501-
uint8_t * pb = (uint8_t *) (pd + nb);
499+
for (int j = 0; j < n; j += k) {
500+
uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
501+
uint8_t * pm = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
502+
uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));
502503

503504
//printf("n = %d, k = %d, nb = %d, row_size = %d, j = %d, pm = %p, pd = %p, pb = %p\n", n, k, nb, row_size, j, pm, pd, pb);
504505

@@ -516,8 +517,10 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
516517
const float d = (max - min) / ((1 << 4) - 1);
517518
const float id = d ? 1.0f/d : 0.0f;
518519

519-
pm[i] = min;
520-
pd[i] = d;
520+
*(float *) pd = d;
521+
*(float *) pm = min;
522+
pd += bs;
523+
pm += bs;
521524

522525
for (int l = 0; l < qk; l += 2) {
523526
const float v0 = (src[j + i*qk + l + 0] - min)*id;
@@ -535,7 +538,8 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
535538
pp[l/2] = vi0 | (vi1 << 4);
536539
}
537540

538-
memcpy(pb + i*qk/2, pp, pp_size);
541+
memcpy(pb, pp, pp_size);
542+
pb += bs;
539543
}
540544
}
541545
}

0 commit comments

Comments
 (0)