Skip to content

Commit 2bfa1fe

Browse files
authored
ggml : AVX2 optimizations for Q5_0, Q5_1 (#1195)
1 parent 982bfce commit 2bfa1fe

File tree

1 file changed

+55
-50
lines changed

1 file changed

+55
-50
lines changed

ggml.c

Lines changed: 55 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -328,8 +328,18 @@ static ggml_fp16_t table_exp_f16[1 << 16];
328328
// precomputed f32 table for f16 (256 KB)
329329
static float table_f32_f16[1 << 16];
330330

331-
// precomputed table for expanding 8bits to 8 bytes (shl 4)
332-
static uint64_t table_b2b[1 << 8];
331+
#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
332+
#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
333+
#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
334+
#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
335+
#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
336+
#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
337+
#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
338+
#define B8(c,s ) B7(c,s, c), B7(c,s, s)
339+
340+
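// precomputed tables for expanding 8 bits to 8 bytes
// B8(c, s) generates all 256 entries: byte b of entry i (byte 0 = least significant)
// is `s` when bit b of i is set and `c` when it is clear
//   table_b2b_u: bit set -> 0x10, bit clear -> 0x00 (the bit shifted left by 4)
//   table_b2b_i: bit set -> 0x00, bit clear -> 0xF0 (high nibble filled when the bit is clear)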
341+
static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
342+
static const uint64_t table_b2b_i[1 << 8] = { B8(F0, 00) };
333343

334344
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
335345
// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
@@ -688,7 +698,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
688698
typedef struct {
689699
ggml_fp16_t d; // delta
690700
ggml_fp16_t m; // min
691-
uint32_t qh; // 5-th bit of quants
701+
uint8_t qh[4]; // 5-th bit of quants
692702
uint8_t qs[QK5_1 / 2]; // nibbles / quants
693703
} block_q5_1;
694704
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
@@ -1376,7 +1386,8 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r
13761386

13771387
y[i].d = GGML_FP32_TO_FP16(d);
13781388
y[i].m = GGML_FP32_TO_FP16(min);
1379-
y[i].qh = 0;
1389+
1390+
uint32_t qh = 0;
13801391

13811392
for (int l = 0; l < QK5_1; l += 2) {
13821393
const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
@@ -1388,9 +1399,11 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r
13881399
y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
13891400

13901401
// get the 5-th bit of each quant and store it in qh: bit l for vi0, bit l + 1 for vi1
1391-
y[i].qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1392-
y[i].qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1402+
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1403+
qh |= ((vi1 & 0x10) >> 4) << (l + 1);
13931404
}
1405+
1406+
memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
13941407
}
13951408
}
13961409

@@ -1966,7 +1979,8 @@ static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, in
19661979

19671980
const uint8_t * restrict pp = x[i].qs;
19681981

1969-
const uint32_t qh = x[i].qh;
1982+
uint32_t qh;
1983+
memcpy(&qh, x[i].qh, sizeof(qh));
19701984

19711985
for (int l = 0; l < QK5_1; l += 2) {
19721986
const uint8_t vi = pp[l/2];
@@ -3297,10 +3311,10 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
32973311
uint32_t qh;
32983312
memcpy(&qh, x0->qh, sizeof(qh));
32993313

3300-
tmp[0] = table_b2b[(qh >> 0) & 0xFF];
3301-
tmp[1] = table_b2b[(qh >> 8) & 0xFF];
3302-
tmp[2] = table_b2b[(qh >> 16) & 0xFF];
3303-
tmp[3] = table_b2b[(qh >> 24) ];
3314+
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3315+
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3316+
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3317+
tmp[3] = table_b2b_u[(qh >> 24) ];
33043318

33053319
const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
33063320
const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
@@ -3350,17 +3364,13 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
33503364
// Main loop
33513365
for (int i = 0; i < nb; i++) {
33523366
/* Compute combined scale for the block */
3353-
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
3354-
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
3355-
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
3356-
3357-
__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
3358-
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
3359-
__m256i bx = _mm256_set_m128i(bx1, bx0);
3367+
const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
33603368

3361-
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
3362-
const __m256i off = _mm256_set1_epi8(8);
3363-
bx = _mm256_sub_epi8(bx, off);
3369+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
3370+
const __m256i bxhi = _mm256_set_epi64x(
3371+
table_b2b_i[x[i].qh[3]], table_b2b_i[x[i].qh[2]],
3372+
table_b2b_i[x[i].qh[1]], table_b2b_i[x[i].qh[0]]);
3373+
bx = _mm256_or_si256(bx, bxhi);
33643374

33653375
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
33663376

@@ -3379,7 +3389,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
33793389
const int8_t * restrict y0 = y[i].qs;
33803390

33813391
uint32_t qh;
3382-
memcpy(&qh, x0->qh, sizeof(qh));
3392+
memcpy(&qh, x[i].qh, sizeof(qh));
33833393

33843394
const float d = GGML_FP16_TO_FP32(x[i].d);
33853395

@@ -3430,12 +3440,13 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
34303440
summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
34313441

34323442
// extract the 5th bit
3433-
const uint32_t qh = x0->qh;
3443+
uint32_t qh;
3444+
memcpy(&qh, x0->qh, sizeof(qh));
34343445

3435-
tmp[0] = table_b2b[(qh >> 0) & 0xFF];
3436-
tmp[1] = table_b2b[(qh >> 8) & 0xFF];
3437-
tmp[2] = table_b2b[(qh >> 16) & 0xFF];
3438-
tmp[3] = table_b2b[(qh >> 24) ];
3446+
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3447+
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3448+
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3449+
tmp[3] = table_b2b_u[(qh >> 24) ];
34393450

34403451
const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
34413452
const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
@@ -3485,16 +3496,15 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
34853496

34863497
// Main loop
34873498
for (int i = 0; i < nb; i++) {
3488-
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
3489-
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
3490-
const __m256 dx = _mm256_set_m128(d1, d0);
3499+
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
34913500

3492-
summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0
3493-
+ GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1;
3501+
summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
34943502

3495-
const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
3496-
const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
3497-
const __m256i bx = _mm256_set_m128i(bx1, bx0);
3503+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
3504+
const __m256i bxhi = _mm256_set_epi64x(
3505+
table_b2b_u[x[i].qh[3]], table_b2b_u[x[i].qh[2]],
3506+
table_b2b_u[x[i].qh[1]], table_b2b_u[x[i].qh[0]]);
3507+
bx = _mm256_or_si256(bx, bxhi);
34983508

34993509
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
35003510
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
@@ -3512,7 +3522,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
35123522
const uint8_t * restrict x0 = x[i].qs;
35133523
const int8_t * restrict y0 = y[i].qs;
35143524

3515-
const uint32_t qh = x[i].qh;
3525+
uint32_t qh;
3526+
memcpy(&qh, x[i].qh, sizeof(qh));
35163527

35173528
const float d = GGML_FP16_TO_FP32(x[i].d);
35183529
const float m = GGML_FP16_TO_FP32(x[i].m);
@@ -4297,15 +4308,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
42974308
table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
42984309
}
42994310

4300-
for (int i = 0; i < 256; ++i) {
4301-
table_b2b[i] = 0;
4302-
for (int b = 0; b < 8; ++b) {
4303-
table_b2b[i] |= ((uint64_t)(((i >> b) & 0x01) << 4)) << (8*b);
4304-
}
4305-
4306-
//printf("%3d %016llx\n", i, table_b2b[i]);
4307-
}
4308-
43094311
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
43104312

43114313
GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
@@ -12855,10 +12857,10 @@ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t *
1285512857
quantize_row_q5_0_reference(src + j, y, k);
1285612858

1285712859
for (int i = 0; i < nb; i++) {
12858-
for (int l = 0; l < QK5_0; l += 2) {
12859-
uint32_t qh;
12860-
memcpy(&qh, &y[i].qh, sizeof(qh));
12860+
uint32_t qh;
12861+
memcpy(&qh, &y[i].qh, sizeof(qh));
1286112862

12863+
for (int l = 0; l < QK5_0; l += 2) {
1286212864
const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
1286312865
const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
1286412866

@@ -12885,9 +12887,12 @@ size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t *
1288512887
quantize_row_q5_1_reference(src + j, y, k);
1288612888

1288712889
for (int i = 0; i < nb; i++) {
12890+
uint32_t qh;
12891+
memcpy(&qh, &y[i].qh, sizeof(qh));
12892+
1288812893
for (int l = 0; l < QK5_1; l += 2) {
12889-
const uint8_t vh0 = ((y[i].qh & (1 << (l + 0))) >> (l + 0)) << 4;
12890-
const uint8_t vh1 = ((y[i].qh & (1 << (l + 1))) >> (l + 1)) << 4;
12894+
const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
12895+
const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
1289112896

1289212897
// cast to 16 bins
1289312898
const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;

0 commit comments

Comments
 (0)