Skip to content

Commit 9a08eaf

Browse files
ikawrakowKawrakow
andauthored
Another speed gain for Q4_0 and Q4_1 on Metal (#2375)
* Another speed gain for Q4_0 and Q4_1 on Metal * Have N_DST, etc., be template parameters --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 129d844 commit 9a08eaf

File tree

1 file changed

+57
-54
lines changed

1 file changed

+57
-54
lines changed

ggml-metal.metal

Lines changed: 57 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -387,87 +387,90 @@ kernel void kernel_rms_norm(
387387
}
388388
}
389389

390-
// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
391-
float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
390+
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
391+
// il indicates where the q4 quants begin (0 or QK4_0/4)
392+
// we assume that the yl's have been multiplied with the appropriate scale factor
393+
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
394+
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
392395
float d = qb_curr->d;
393-
float4 acc = 0.f;
394-
device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
395-
for (int i = 0; i < 16; i+=2) {
396-
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
397-
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
398-
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
399-
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
396+
float2 acc = 0.f;
397+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
398+
for (int i = 0; i < 8; i+=2) {
399+
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
400+
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
401+
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
402+
+ yl[i + 9] * (qs[i / 2] & 0xF000);
400403
}
401-
return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
404+
return d * (sumy * -8.f + acc[0] + acc[1]);
402405
}
403406

404-
// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
405-
float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
407+
// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
408+
// il indicates where the q4 quants begin (0 or QK4_0/4)
409+
// we assume that the yl's have been multiplied with the appropriate scale factor
410+
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
411+
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
406412
float d = qb_curr->d;
407413
float m = qb_curr->m;
408-
float4 acc = 0.f;
409-
device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
410-
for (int i = 0; i < 16; i+=2) {
411-
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
412-
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
413-
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
414-
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
414+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
415+
float2 acc = 0.f;
416+
for (int i = 0; i < 8; i+=2) {
417+
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
418+
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
419+
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
420+
+ yl[i + 9] * (qs[i / 2] & 0xF000);
415421
}
416-
return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
422+
return d * (acc[0] + acc[1]) + sumy * m;
417423
}
418424

419425
// putting them in the kernel cause a significant performance penalty
420426
#define N_DST 4 // each SIMD group works on 4 rows
421427
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
422428
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
423-
template<typename block_q_type>
429+
//Note: This is a template, but strictly speaking it only applies to
430+
// quantizations where the block size is 32. It also does not
431+
// giard against the number of rows not being divisible by
432+
// N_DST, so this is another explicit assumption of the implementation.
433+
template<typename block_q_type, int nr, int nsg, int nw>
424434
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
425435
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
426436
uint2 tgpig, uint tiisg, uint sgitg) {
427437
const int nb = ne00/QK4_0;
428438
const int r0 = tgpig.x;
429439
const int r1 = tgpig.y;
430-
device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
440+
const int first_row = (r0 * nsg + sgitg) * nr;
441+
device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
431442
device const float * y = (device const float *) src1 + r1*ne10;
432-
float4 y_curr[8]; // src1 vector cache
433-
float sumf[N_DST]={0.f}, all_sum;
434-
thread float * yl=(thread float *)y_curr;
443+
float yl[16]; // src1 vector cache
444+
float sumf[nr]={0.f};
435445

436-
// each thread in a SIMD group deals with 1 block.
437-
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
438-
float sumy = 0;
439-
for (int i = 0; i < QK4_0 / 4; i++) {
440-
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
441-
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
442-
}
446+
const int ix = tiisg/2;
447+
const int il = 8*(tiisg%2);
443448

444-
for (int row = 0; row < N_DST; row++) {
445-
sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
446-
}
447-
}
449+
device const float * yb = y + ix * QK4_0 + il;
448450

449-
// from now loads two rows every time and 16 blocks per row
450-
int ir = tiisg / (N_SIMDWIDTH / 2);
451-
int ib = tiisg % (N_SIMDWIDTH / 2);
452-
for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
453-
int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
451+
// each thread in a SIMD group deals with half a block.
452+
for (int ib = ix; ib < nb; ib += nw/2) {
454453
float sumy = 0;
455-
for (int i = 0; i < QK4_0 / 4; i++) {
456-
y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
457-
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
454+
for (int i = 0; i < 8; i += 2) {
455+
sumy += yb[i] + yb[i+1];
456+
yl[i+0] = yb[i+ 0];
457+
yl[i+1] = yb[i+ 1]/256.f;
458+
sumy += yb[i+16] + yb[i+17];
459+
yl[i+8] = yb[i+16]/16.f;
460+
yl[i+9] = yb[i+17]/4096.f;
458461
}
459462

460-
for (int row = 0; row < N_DST; row+=2) {
461-
if (nb_start + ib < nb) {
462-
sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
463-
}
463+
for (int row = 0; row < nr; row++) {
464+
sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
464465
}
466+
467+
yb += QK4_0 * 16;
465468
}
466469

467-
for (int row = 0; row < N_DST; ++row) {
468-
all_sum = simd_sum(sumf[row]);
469-
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
470-
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
470+
for (int row = 0; row < nr; ++row) {
471+
const float tot = simd_sum(sumf[row]);
472+
if (tiisg == 0 && first_row + row < ne01) {
473+
dst[r1*ne0 + first_row + row] = tot;
471474
}
472475
}
473476
}
@@ -483,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32(
483486
uint2 tgpig[[threadgroup_position_in_grid]],
484487
uint tiisg[[thread_index_in_simdgroup]],
485488
uint sgitg[[simdgroup_index_in_threadgroup]]) {
486-
mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
489+
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
487490
}
488491

489492
kernel void kernel_mul_mat_q4_1_f32(
@@ -497,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32(
497500
uint2 tgpig[[threadgroup_position_in_grid]],
498501
uint tiisg[[thread_index_in_simdgroup]],
499502
uint sgitg[[simdgroup_index_in_threadgroup]]) {
500-
mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
503+
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
501504
}
502505

503506
kernel void kernel_mul_mat_f16_f32(

0 commit comments

Comments
 (0)