Skip to content

Commit f3f2e8e

Browse files
committed
metal: use template to reduce size
Revert modifications on block_q4_0 and block_q4_1.
1 parent 4088df1 commit f3f2e8e

File tree

1 file changed

+78
-169
lines changed

1 file changed

+78
-169
lines changed

ggml-metal.metal

Lines changed: 78 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ using namespace metal;
88
#define QR4_0 2
99
typedef struct {
1010
half d; // delta
11-
uint16_t qs[QK4_0 / 4]; // nibbles / quants
11+
uint8_t qs[QK4_0 / 2]; // nibbles / quants
1212
} block_q4_0;
1313

1414
#define QK4_1 32
1515
typedef struct {
1616
half d; // delta
1717
half m; // min
18-
uint16_t qs[QK4_1 / 4]; // nibbles / quants
18+
uint8_t qs[QK4_1 / 2]; // nibbles / quants
1919
} block_q4_1;
2020

2121
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
@@ -28,16 +28,12 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
2828
for (int i = 0; i < nb; i++) {
2929
const half d = x[i].d;
3030

31-
for (int j = 0; j < qk/4; ++j) {
32-
const int x0 = (x[i].qs[j] & 0x000F) - 8;
33-
const int x1 = ((x[i].qs[j] & 0x00F0) >> 4) - 8;
34-
const int x2 = ((x[i].qs[j] & 0x0F00) >> 8) - 8;
35-
const int x3 = ((x[i].qs[j] & 0xF000) >> 12) - 8;
31+
for (int j = 0; j < qk/2; ++j) {
32+
const int x0 = (x[i].qs[j] & 0x0F) - 8;
33+
const int x1 = (x[i].qs[j] >> 4) - 8;
3634

37-
y[i*qk + 2 * j + 0 ] = x0*d;
38-
y[i*qk + 2 * j + qk/2 ] = x1*d;
39-
y[i*qk + 2 * j + 1 ] = x2*d;
40-
y[i*qk + 2 * j + 1 + qk/2] = x3*d;
35+
y[i*qk + j + 0 ] = x0*d;
36+
y[i*qk + j + qk/2] = x1*d;
4137
}
4238
}
4339
}
@@ -53,16 +49,12 @@ static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, i
5349
const half d = x[i].d;
5450
const half m = x[i].m;
5551

56-
for (int j = 0; j < qk/4; ++j) {
57-
const int x0 = (x[i].qs[j] & 0x000F);
58-
const int x1 = ((x[i].qs[j] & 0x00F0) >> 4);
59-
const int x2 = ((x[i].qs[j] & 0x0F00) >> 8);
60-
const int x3 = ((x[i].qs[j] & 0xF000) >> 12);
52+
for (int j = 0; j < qk/2; ++j) {
53+
const int x0 = (x[i].qs[j] & 0x0F);
54+
const int x1 = (x[i].qs[j] >> 4);
6155

62-
y[i*qk + 2 * j + 0 ] = x0*d + m;
63-
y[i*qk + 2 * j + qk/2 ] = x1*d + m;
64-
y[i*qk + 2 * j + 1 ] = x2*d + m;
65-
y[i*qk + 2 * j + 1 + qk/2] = x3*d + m;
56+
y[i*qk + j + 0 ] = x0*d + m;
57+
y[i*qk + j + qk/2] = x1*d + m;
6658
}
6759
}
6860
}
@@ -384,107 +376,92 @@ kernel void kernel_rms_norm(
384376
}
385377
}
386378

379+
// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
380+
float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
381+
float d = qb_curr->d;
382+
float4 acc = 0.f;
383+
device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
384+
for (int i = 0; i < 16; i+=2) {
385+
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
386+
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
387+
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
388+
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
389+
}
390+
return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
391+
}
392+
393+
// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
394+
float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
395+
float d = qb_curr->d;
396+
float m = qb_curr->m;
397+
float4 acc = 0.f;
398+
device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
399+
for (int i = 0; i < 16; i+=2) {
400+
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
401+
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
402+
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
403+
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
404+
}
405+
return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
406+
}
407+
387408
// putting them in the kernel cause a significant performance penalty
388409
#define N_DST 4 // each SIMD group works on 4 rows
389410
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
390411
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
391-
kernel void kernel_mul_mat_q4_0_f32(
392-
device const void * src0,
393-
device const float * src1,
394-
device float * dst,
395-
constant int64_t & ne00,
396-
constant int64_t & ne10,
397-
constant int64_t & ne0,
398-
constant int64_t & ne01[[buffer(4)]],
399-
uint2 tgpig[[threadgroup_position_in_grid]],
400-
uint tiisg[[thread_index_in_simdgroup]],
401-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
412+
template<typename block_q_type>
413+
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
414+
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
415+
uint2 tgpig, uint tiisg, uint sgitg) {
402416
const int nb = ne00/QK4_0;
403417
const int r0 = tgpig.x;
404418
const int r1 = tgpig.y;
405-
device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
419+
device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
406420
device const float * y = (device const float *) src1 + r1*ne10;
407-
block_q4_0 qb_curr, qb_next;
408421
float4 y_curr[8]; // src1 vector cache
409422
float sumf[N_DST]={0.f}, all_sum;
410423
thread float * yl=(thread float *)y_curr;
411424

412-
// bootstrap
413-
qb_curr = x[tiisg];
414425
// each thread in a SIMD group deals with 1 block.
415426
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
416-
417427
float sumy = 0;
418428
for (int i = 0; i < QK4_0 / 4; i++) {
419-
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
429+
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
420430
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
421431
}
422-
sumy *= (-8.f);
423-
// we don't right shift packed 4-bit weights, so we have to devide y by 16/256/4096 to conpensate this.
424-
for (int i = 0; i < 32; i++) {
425-
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
426-
}
427432

428433
for (int row = 0; row < N_DST; row++) {
429-
// prefetch next x block
430-
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
431-
432-
// calculate
433-
float d = qb_curr.d;
434-
float acc = sumy;
435-
for (int i = 0; i < 16; i+=2) {
436-
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
437-
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
438-
}
439-
sumf[row] += d * acc;
440-
qb_curr = qb_next;
434+
sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
441435
}
442436
}
443437

444-
if (nb % N_SIMDWIDTH == 0) {
445-
for (int row = 0; row < N_DST; ++row) {
446-
all_sum = simd_sum(sumf[row]);
447-
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
448-
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
449-
}
450-
}
451-
} else {
452-
438+
// from now loads two rows every time and 16 blocks per row
439+
int ir = tiisg / (N_SIMDWIDTH / 2);
440+
int ib = tiisg % (N_SIMDWIDTH / 2);
441+
for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
442+
int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
453443
float sumy = 0;
454444
for (int i = 0; i < QK4_0 / 4; i++) {
455-
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
445+
y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
456446
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
457447
}
458-
sumy *= (-8.f);
459-
for (int i = 0; i < 32; i++) {
460-
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
461-
}
462448

463-
for (int row = 0; row < N_DST; row++) {
464-
// prefetch next x block
465-
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
466-
467-
// calculate
468-
float d = qb_curr.d;
469-
float acc = sumy;
470-
for (int i = 0; i < 16; i+=2) {
471-
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
472-
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
473-
}
474-
if (tiisg < nb % N_SIMDWIDTH) {
475-
sumf[row] += d * acc;
449+
for (int row = 0; row < N_DST; row+=2) {
450+
if (nb_start + ib < nb) {
451+
sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
476452
}
477-
qb_curr = qb_next;
453+
}
454+
}
478455

479-
all_sum = simd_sum(sumf[row]);
480-
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
481-
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
482-
}
456+
for (int row = 0; row < N_DST; ++row) {
457+
all_sum = simd_sum(sumf[row]);
458+
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
459+
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
483460
}
484461
}
485462
}
486463

487-
kernel void kernel_mul_mat_q4_1_f32(
464+
kernel void kernel_mul_mat_q4_0_f32(
488465
device const void * src0,
489466
device const float * src1,
490467
device float * dst,
@@ -495,89 +472,21 @@ kernel void kernel_mul_mat_q4_1_f32(
495472
uint2 tgpig[[threadgroup_position_in_grid]],
496473
uint tiisg[[thread_index_in_simdgroup]],
497474
uint sgitg[[simdgroup_index_in_threadgroup]]) {
498-
const int nb = ne00/QK4_0;
499-
const int r0 = tgpig.x;
500-
const int r1 = tgpig.y;
501-
device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
502-
device const float * y = (device const float *) src1 + r1*ne10;
503-
block_q4_1 qb_curr, qb_next;
504-
float4 y_curr[8]; // src1 vector cache
505-
float sumf[N_DST]={0.f}, all_sum;
506-
thread float * yl=(thread float *)y_curr;
507-
508-
// bootstrap
509-
qb_curr = x[tiisg];
510-
// each thread in a SIMD group deals with 1 block.
511-
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
512-
513-
float sumy = 0;
514-
for (int i = 0; i < QK4_0 / 4; i++) {
515-
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
516-
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
517-
}
518-
// we don't right shift packed 4-bit weights, so we have to devide y by 16/256/4096 to conpensate this.
519-
for (int i = 0; i < 32; i++) {
520-
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
521-
}
522-
523-
for (int row = 0; row < N_DST; row++) {
524-
// prefetch next x block
525-
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
526-
527-
// calculate
528-
const float d = qb_curr.d;
529-
const float m = qb_curr.m;
530-
float acc = 0.f;
531-
for (int i = 0; i < 16; i+=2) {
532-
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
533-
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
534-
}
535-
sumf[row] += d * acc + m * sumy;
536-
qb_curr = qb_next;
537-
}
538-
}
539-
540-
if (nb % N_SIMDWIDTH == 0) {
541-
for (int row = 0; row < N_DST; ++row) {
542-
all_sum = simd_sum(sumf[row]);
543-
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
544-
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
545-
}
546-
}
547-
} else {
548-
549-
float sumy = 0;
550-
for (int i = 0; i < QK4_0 / 4; i++) {
551-
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
552-
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
553-
}
554-
for (int i = 0; i < 32; i++) {
555-
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
556-
}
557-
558-
for (int row = 0; row < N_DST; row++) {
559-
// prefetch next x block
560-
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
561-
562-
// calculate
563-
const float d = qb_curr.d;
564-
const float m = qb_curr.m;
565-
float acc = 0.f;
566-
for (int i = 0; i < 16; i+=2) {
567-
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
568-
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
569-
}
570-
if (tiisg < nb % N_SIMDWIDTH) {
571-
sumf[row] += d * acc + m * sumy;
572-
}
573-
qb_curr = qb_next;
475+
mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
476+
}
574477

575-
all_sum = simd_sum(sumf[row]);
576-
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
577-
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
578-
}
579-
}
580-
}
478+
kernel void kernel_mul_mat_q4_1_f32(
479+
device const void * src0,
480+
device const float * src1,
481+
device float * dst,
482+
constant int64_t & ne00,
483+
constant int64_t & ne10,
484+
constant int64_t & ne0,
485+
constant int64_t & ne01[[buffer(4)]],
486+
uint2 tgpig[[threadgroup_position_in_grid]],
487+
uint tiisg[[thread_index_in_simdgroup]],
488+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
489+
mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
581490
}
582491

583492
kernel void kernel_mul_mat_f16_f32(

0 commit comments

Comments
 (0)