Skip to content

Commit 94f5d4a

Browse files
committed
ggml : remove Q5_1 bit shuffling (ARM NEON + reference)
1 parent b639b45 commit 94f5d4a

File tree

1 file changed

+66
-106
lines changed

1 file changed

+66
-106
lines changed

ggml.c

Lines changed: 66 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -844,8 +844,7 @@ static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block s
844844
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
845845
static const int qk = QK4_0;
846846

847-
assert(qk / 16 == 0);
848-
assert( k % qk == 0);
847+
assert(k % qk == 0);
849848

850849
const int nb = k / qk;
851850

@@ -866,20 +865,16 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
866865

867866
y[i].d = d;
868867

869-
uint64_t qs[QK4_0 / 16] = {0};
870-
871868
for (int l = 0; l < qk/2; ++l) {
872869
const float x0 = x[i*qk + 0 + l]*id;
873870
const float x1 = x[i*qk + qk/2 + l]*id;
874871

875-
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
876-
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
872+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
873+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
877874

878-
qs[l/8] |= xi0 << (8*(l & 7));
879-
qs[l/8] |= xi1 << (8*(l & 7) + 4);
875+
y[i].qs[l] = xi0;
876+
y[i].qs[l] |= xi1 << 4;
880877
}
881-
882-
memcpy(y[i].qs, qs, qk/2);
883878
}
884879
}
885880

@@ -890,8 +885,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k
890885
static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
891886
const int qk = QK4_1;
892887

893-
assert(qk / 16 == 0);
894-
assert( k % qk == 0);
888+
assert(k % qk == 0);
895889

896890
const int nb = k / qk;
897891

@@ -912,20 +906,16 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
912906
y[i].d = d;
913907
y[i].m = min;
914908

915-
uint64_t qs[QK4_1 / 16] = {0};
916-
917909
for (int l = 0; l < qk/2; ++l) {
918910
const float x0 = (x[0 + l] - min)*id;
919911
const float x1 = (x[qk/2 + l] - min)*id;
920912

921-
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
922-
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
913+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
914+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
923915

924-
qs[l/8] |= xi0 << (8*(l & 7));
925-
qs[l/8] |= xi1 << (8*(l & 7) + 4);
916+
y[i].qs[l] = xi0;
917+
y[i].qs[l] |= xi1 << 4;
926918
}
927-
928-
memcpy(y[i].qs, qs, qk/2);
929919
}
930920
}
931921

@@ -937,8 +927,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k
937927
static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
938928
static const int qk = QK4_2;
939929

940-
assert(qk / 16 == 0);
941-
assert( k % qk == 0);
930+
assert(k % qk == 0);
942931

943932
const int nb = k / qk;
944933

@@ -983,8 +972,7 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict y, int k
983972
static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
984973
static const int qk = QK5_0;
985974

986-
assert(qk / 16 == 0);
987-
assert( k % qk == 0);
975+
assert(k % qk == 0);
988976

989977
const int nb = k / qk;
990978

@@ -1006,24 +994,21 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
1006994
y[i].d = d;
1007995

1008996
uint32_t qh = 0;
1009-
uint64_t qs[QK5_0 / 16] = {0};
1010997

1011998
for (int l = 0; l < qk/2; ++l) {
1012999
const float x0 = x[i*qk + 0 + l]*id;
10131000
const float x1 = x[i*qk + qk/2 + l]*id;
10141001

1015-
const uint64_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
1016-
const uint64_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
1002+
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
1003+
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
10171004

1018-
qs[l/8] |= xi0 << (8*(l & 7));
1019-
qs[l/8] |= xi1 << (8*(l & 7) + 4);
1005+
y[i].qs[l] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
10201006

10211007
// get the 5-th bit and store it in qh at the right position
10221008
qh |= ((xi0 & 0x10) >> 4) << (l + 0);
10231009
qh |= ((xi1 & 0x10) >> 4) << (l + qk/2);
10241010
}
10251011

1026-
memcpy( y[i].qs, qs, qk/2);
10271012
memcpy(&y[i].qh, &qh, sizeof(qh));
10281013
}
10291014
}
@@ -1033,50 +1018,50 @@ static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k
10331018
}
10341019

10351020
static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
1036-
assert(k % QK5_1 == 0);
1037-
const int nb = k / QK5_1;
1021+
const int qk = QK5_1;
1022+
1023+
assert(k % qk == 0);
1024+
1025+
const int nb = k / qk;
10381026

10391027
for (int i = 0; i < nb; i++) {
10401028
float min = FLT_MAX;
10411029
float max = -FLT_MAX;
10421030

1043-
for (int l = 0; l < QK5_1; l++) {
1044-
const float v = x[i*QK5_1 + l];
1031+
for (int l = 0; l < qk; l++) {
1032+
const float v = x[i*qk + l];
1033+
10451034
if (v < min) min = v;
10461035
if (v > max) max = v;
10471036
}
10481037

1049-
const float d = (max - min) / ((1 << 5) - 1);
1038+
const float d = (max - min) / ((1 << 5) - 1);
10501039
const float id = d ? 1.0f/d : 0.0f;
10511040

10521041
y[i].d = GGML_FP32_TO_FP16(d);
10531042
y[i].m = GGML_FP32_TO_FP16(min);
10541043

10551044
uint32_t qh = 0;
10561045

1057-
for (int l = 0; l < QK5_1; l += 2) {
1058-
const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
1059-
const float v1 = (x[i*QK5_1 + l + 1] - min)*id;
1046+
for (int l = 0; l < qk/2; ++l) {
1047+
const float x0 = (x[i*qk + 0 + l] - min)*id;
1048+
const float x1 = (x[i*qk + qk/2 + l] - min)*id;
10601049

1061-
const uint32_t vi0 = (int) (v0 + 0.5f);
1062-
const uint32_t vi1 = (int) (v1 + 0.5f);
1050+
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
1051+
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
10631052

1064-
y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
1053+
y[i].qs[l] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
10651054

10661055
// get the 5-th bit and store it in qh at the right position
1067-
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1068-
qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1056+
qh |= ((xi0 & 0x10) >> 4) << (l + 0);
1057+
qh |= ((xi1 & 0x10) >> 4) << (l + qk/2);
10691058
}
10701059

10711060
memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
10721061
}
10731062
}
10741063

1075-
static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
1076-
assert(k % QK5_1 == 0);
1077-
1078-
block_q5_1 * restrict y = vy;
1079-
1064+
static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
10801065
quantize_row_q5_1_reference(x, y, k);
10811066
}
10821067

@@ -1316,8 +1301,7 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
13161301
static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
13171302
static const int qk = QK4_0;
13181303

1319-
assert(qk / 16 == 0);
1320-
assert( k % qk == 0);
1304+
assert(k % qk == 0);
13211305

13221306
const int nb = k / qk;
13231307

@@ -1337,8 +1321,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
13371321
static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
13381322
static const int qk = QK4_1;
13391323

1340-
assert(qk / 16 == 0);
1341-
assert( k % qk == 0);
1324+
assert(k % qk == 0);
13421325

13431326
const int nb = k / qk;
13441327

@@ -1360,8 +1343,7 @@ static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict
13601343
// BORKEN !!!
13611344
static const int qk = QK4_2;
13621345

1363-
assert(qk / 16 == 0);
1364-
assert( k % qk == 0);
1346+
assert(k % qk == 0);
13651347

13661348
const int nb = k / qk;
13671349

@@ -1381,8 +1363,7 @@ static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict
13811363
static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
13821364
static const int qk = QK4_0;
13831365

1384-
assert(qk / 16 == 0);
1385-
assert( k % qk == 0);
1366+
assert(k % qk == 0);
13861367

13871368
const int nb = k / qk;
13881369

@@ -1405,39 +1386,29 @@ static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict
14051386
}
14061387
}
14071388

1408-
static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
1409-
assert(k % QK5_1 == 0);
1410-
const int nb = k / QK5_1;
1389+
static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
1390+
static const int qk = QK5_1;
14111391

1412-
const block_q5_1 * restrict x = vx;
1392+
assert(k % qk == 0);
1393+
1394+
const int nb = k / qk;
14131395

14141396
for (int i = 0; i < nb; i++) {
14151397
const float d = GGML_FP16_TO_FP32(x[i].d);
14161398
const float m = GGML_FP16_TO_FP32(x[i].m);
14171399

1418-
const uint8_t * restrict pp = x[i].qs;
1419-
14201400
uint32_t qh;
14211401
memcpy(&qh, x[i].qh, sizeof(qh));
14221402

1423-
for (int l = 0; l < QK5_1; l += 2) {
1424-
const uint8_t vi = pp[l/2];
1425-
1426-
// extract the 5-th bit from qh
1427-
const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
1428-
const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
1429-
1430-
const uint8_t vi0 = (vi & 0x0F) | vh0;
1431-
const uint8_t vi1 = (vi >> 4) | vh1;
1432-
1433-
const float v0 = vi0*d + m;
1434-
const float v1 = vi1*d + m;
1403+
for (int j = 0; j < qk/2; ++j) {
1404+
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
1405+
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
14351406

1436-
y[i*QK5_1 + l + 0] = v0;
1437-
y[i*QK5_1 + l + 1] = v1;
1407+
const int x0 = (x[i].qs[j] & 0xf) | xh_0;
1408+
const int x1 = (x[i].qs[j] >> 4) | xh_1;
14381409

1439-
assert(!isnan(y[i*QK5_1 + l + 0]));
1440-
assert(!isnan(y[i*QK5_1 + l + 1]));
1410+
y[i*qk + j + 0 ] = x0*d + m;
1411+
y[i*qk + j + qk/2] = x1*d + m;
14411412
}
14421413
}
14431414
}
@@ -1500,7 +1471,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
15001471
.vec_dot_type = GGML_TYPE_Q8_0,
15011472
},
15021473
[GGML_TYPE_Q5_1] = {
1503-
.dequantize_row_q = dequantize_row_q5_1,
1474+
.dequantize_row_q = (dequantize_row_q_t) dequantize_row_q5_1,
15041475
.quantize_row_q = quantize_row_q5_1,
15051476
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
15061477
.quantize_row_q_dot = quantize_row_q8_1,
@@ -2748,11 +2719,12 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
27482719
}
27492720

27502721
static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2751-
const int nb = n / QK8_1;
2722+
const int qk = QK8_1;
2723+
const int nb = n / qk;
27522724

2753-
assert(n % QK8_1 == 0);
2725+
assert(n % qk == 0);
27542726
assert(nb % 2 == 0);
2755-
assert(QK8_1 == QK5_1);
2727+
assert(qk == QK5_1);
27562728

27572729
const block_q5_1 * restrict x = vx;
27582730
const block_q8_1 * restrict y = vy;
@@ -2788,13 +2760,9 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
27882760
const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
27892761
const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
27902762

2791-
// interleave
2792-
const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
2793-
const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
2794-
27952763
// add
2796-
const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
2797-
const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
2764+
const int8x16_t v0lf = vorrq_s8(v0l, qhl);
2765+
const int8x16_t v0hf = vorrq_s8(v0h, qhh);
27982766

27992767
// load y
28002768
const int8x16_t v1l = vld1q_s8(y0->qs);
@@ -2917,36 +2885,28 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
29172885

29182886
*s = hsum_float_8(acc) + summs;
29192887
#else
2888+
// scalar
29202889
float sumf = 0.0;
29212890

29222891
for (int i = 0; i < nb; i++) {
2923-
const uint8_t * restrict x0 = x[i].qs;
2924-
const int8_t * restrict y0 = y[i].qs;
2892+
const int8_t * py = y[i].qs;
29252893

29262894
uint32_t qh;
29272895
memcpy(&qh, x[i].qh, sizeof(qh));
29282896

2929-
const float d = GGML_FP16_TO_FP32(x[i].d);
2930-
const float m = GGML_FP16_TO_FP32(x[i].m);
2931-
2932-
int sxy = 0;
2933-
2934-
for (int j = 0; j < QK8_1/2; j++) {
2935-
const uint8_t v0 = x0[j];
2936-
2937-
const int x0_0h = ((qh & (1u << (2*j + 0))) >> (2*j + 0)) << 4;
2938-
const int x1_0h = ((qh & (1u << (2*j + 1))) >> (2*j + 1)) << 4;
2897+
int sumi = 0;
29392898

2940-
const int x0_0 = (v0 & 0x0F) | x0_0h;
2941-
const int x1_0 = (v0 >> 4) | x1_0h;
2899+
for (int j = 0; j < qk/2; ++j) {
2900+
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
2901+
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
29422902

2943-
const int y0_0 = y0[2*j + 0];
2944-
const int y1_0 = y0[2*j + 1];
2903+
const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
2904+
const int32_t x1 = (x[i].qs[j] >> 4) | xh_1;
29452905

2946-
sxy += x0_0*y0_0 + x1_0*y1_0;
2906+
sumi += (x0 * py[j]) + (x1 * py[j + qk/2]);
29472907
}
29482908

2949-
sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
2909+
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*(y[i].s0 + y[i].s1);
29502910
}
29512911

29522912
*s = sumf;

0 commit comments

Comments
 (0)