Skip to content

Commit 417a85a

Browse files
authored
metal: minor q4 optimization and reduce code size (#2248)
* metal: use uint16_t instead of uint8_t. Apple GPUs don't like uint8_t: for every operation on a uint8_t, the GPU needs to copy the value into an empty 16-bit register before it can issue other instructions. For the matrix-vector multiplication kernel only, we observed a 340~350 GB/s memory read speed on M1 Max after this commit, which is very close to the reported hardware limit. * metal: update rms_norm kernel. This commit doubles the speed of rms_norm operations by using 512 threads per threadgroup, combined with SIMD primitives to minimize the need for threadgroup barriers. * metal: use template to reduce size. Revert modifications on block_q4_0 and block_q4_1.
1 parent 294f424 commit 417a85a

File tree

2 files changed

+93
-147
lines changed

2 files changed

+93
-147
lines changed

ggml-metal.m

+2-2
Original file line numberDiff line numberDiff line change
@@ -792,15 +792,15 @@ void ggml_metal_graph_compute(
792792

793793
const float eps = 1e-6f;
794794

795-
const int nth = 256;
795+
const int nth = 512;
796796

797797
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
798798
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
799799
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
800800
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
801801
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
802802
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
803-
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
803+
[encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
804804

805805
const int64_t nrows = ggml_nrows(src0);
806806

ggml-metal.metal

+91-145
Original file line numberDiff line numberDiff line change
@@ -331,26 +331,33 @@ kernel void kernel_rms_norm(
331331
threadgroup float * sum [[threadgroup(0)]],
332332
uint tgpig[[threadgroup_position_in_grid]],
333333
uint tpitg[[thread_position_in_threadgroup]],
334+
uint sgitg[[simdgroup_index_in_threadgroup]],
335+
uint tiisg[[thread_index_in_simdgroup]],
334336
uint ntg[[threads_per_threadgroup]]) {
335-
device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
337+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
338+
device const float * x_scalar = (device const float *) x;
339+
float4 sumf=0;
340+
float all_sum=0;
336341

337342
// parallel sum
338-
sum[tpitg] = 0.0f;
339-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
340-
sum[tpitg] += x[i00] * x[i00];
343+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
344+
sumf += x[i00] * x[i00];
345+
}
346+
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
347+
all_sum = simd_sum(all_sum);
348+
if (tiisg == 0) {
349+
sum[sgitg] = all_sum;
341350
}
342351

343-
// reduce
344352
threadgroup_barrier(mem_flags::mem_threadgroup);
345-
for (uint i = ntg/2; i > 0; i /= 2) {
346-
if (tpitg < i) {
347-
sum[tpitg] += sum[tpitg + i];
348-
}
349-
threadgroup_barrier(mem_flags::mem_threadgroup);
353+
// broadcast, simd group number is ntg / 32
354+
for (int i = ntg / 32 / 2; i > 0; i /= 2) {
355+
if (tpitg < i) {
356+
sum[tpitg] += sum[tpitg + i];
357+
}
350358
}
351-
352-
// broadcast
353359
if (tpitg == 0) {
360+
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
354361
sum[0] /= ne00;
355362
}
356363

@@ -359,104 +366,102 @@ kernel void kernel_rms_norm(
359366
const float mean = sum[0];
360367
const float scale = 1.0f/sqrt(mean + eps);
361368

362-
device float * y = dst + tgpig*ne00;
363-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
369+
device float4 * y = (device float4 *) (dst + tgpig*ne00);
370+
device float * y_scalar = (device float *) y;
371+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
364372
y[i00] = x[i00] * scale;
365373
}
374+
if (tpitg == 0) {
375+
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
376+
}
377+
}
378+
379+
// function to calculate the inner product between a q4_0 block and 32 floats (yl); sumy is SUM(yl[i])
380+
float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
381+
float d = qb_curr->d;
382+
float4 acc = 0.f;
383+
device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
384+
for (int i = 0; i < 16; i+=2) {
385+
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
386+
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
387+
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
388+
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
389+
}
390+
return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
391+
}
392+
393+
// function to calculate the inner product between a q4_1 block and 32 floats (yl); sumy is SUM(yl[i])
394+
float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
395+
float d = qb_curr->d;
396+
float m = qb_curr->m;
397+
float4 acc = 0.f;
398+
device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
399+
for (int i = 0; i < 16; i+=2) {
400+
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
401+
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
402+
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
403+
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
404+
}
405+
return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
366406
}
367407

368408
// putting them in the kernel cause a significant performance penalty
369409
#define N_DST 4 // each SIMD group works on 4 rows
370410
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
371411
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
372-
kernel void kernel_mul_mat_q4_0_f32(
373-
device const void * src0,
374-
device const float * src1,
375-
device float * dst,
376-
constant int64_t & ne00,
377-
constant int64_t & ne10,
378-
constant int64_t & ne0,
379-
constant int64_t & ne01[[buffer(4)]],
380-
uint2 tgpig[[threadgroup_position_in_grid]],
381-
uint tiisg[[thread_index_in_simdgroup]],
382-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
412+
template<typename block_q_type>
413+
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
414+
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
415+
uint2 tgpig, uint tiisg, uint sgitg) {
383416
const int nb = ne00/QK4_0;
384417
const int r0 = tgpig.x;
385418
const int r1 = tgpig.y;
386-
device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
419+
device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
387420
device const float * y = (device const float *) src1 + r1*ne10;
388-
block_q4_0 qb_curr, qb_next;
389421
float4 y_curr[8]; // src1 vector cache
390422
float sumf[N_DST]={0.f}, all_sum;
391423
thread float * yl=(thread float *)y_curr;
392424

393-
// bootstrap
394-
qb_curr = x[tiisg];
395425
// each thread in a SIMD group deals with 1 block.
396426
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
397-
398427
float sumy = 0;
399428
for (int i = 0; i < QK4_0 / 4; i++) {
400-
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
429+
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
401430
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
402431
}
403-
sumy *= (-8.f);
404432

405433
for (int row = 0; row < N_DST; row++) {
406-
// prefetch next x block
407-
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
408-
409-
// calculate
410-
float d = qb_curr.d;
411-
float acc = sumy;
412-
for (int i = 0; i < 16; i++) {
413-
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
414-
}
415-
sumf[row] += d * acc;
416-
qb_curr = qb_next;
434+
sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
417435
}
418436
}
419437

420-
if (nb % N_SIMDWIDTH == 0) {
421-
for (int row = 0; row < N_DST; ++row) {
422-
all_sum = simd_sum(sumf[row]);
423-
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
424-
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
425-
}
426-
}
427-
} else {
428-
438+
// from now loads two rows every time and 16 blocks per row
439+
int ir = tiisg / (N_SIMDWIDTH / 2);
440+
int ib = tiisg % (N_SIMDWIDTH / 2);
441+
for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
442+
int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
429443
float sumy = 0;
430444
for (int i = 0; i < QK4_0 / 4; i++) {
431-
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
445+
y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
432446
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
433447
}
434-
sumy *= (-8.f);
435448

436-
for (int row = 0; row < N_DST; row++) {
437-
// prefetch next x block
438-
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
439-
440-
// calculate
441-
float d = qb_curr.d;
442-
float acc = sumy;
443-
for (int i = 0; i < 16; i++) {
444-
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
449+
for (int row = 0; row < N_DST; row+=2) {
450+
if (nb_start + ib < nb) {
451+
sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
445452
}
446-
if (tiisg < nb % N_SIMDWIDTH) {
447-
sumf[row] += d * acc;
448-
}
449-
qb_curr = qb_next;
453+
}
454+
}
450455

451-
all_sum = simd_sum(sumf[row]);
452-
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
453-
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
454-
}
456+
for (int row = 0; row < N_DST; ++row) {
457+
all_sum = simd_sum(sumf[row]);
458+
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
459+
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
455460
}
456461
}
457462
}
458463

459-
kernel void kernel_mul_mat_q4_1_f32(
464+
kernel void kernel_mul_mat_q4_0_f32(
460465
device const void * src0,
461466
device const float * src1,
462467
device float * dst,
@@ -467,80 +472,21 @@ kernel void kernel_mul_mat_q4_1_f32(
467472
uint2 tgpig[[threadgroup_position_in_grid]],
468473
uint tiisg[[thread_index_in_simdgroup]],
469474
uint sgitg[[simdgroup_index_in_threadgroup]]) {
470-
const int nb = ne00/QK4_0;
471-
const int r0 = tgpig.x;
472-
const int r1 = tgpig.y;
473-
device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
474-
device const float * y = (device const float *) src1 + r1*ne10;
475-
block_q4_1 qb_curr, qb_next;
476-
float4 y_curr[8]; // src1 vector cache
477-
float sumf[N_DST]={0.f}, all_sum;
478-
thread float * yl=(thread float *)y_curr;
479-
480-
// bootstrap
481-
qb_curr = x[tiisg];
482-
// each thread in a SIMD group deals with 1 block.
483-
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
484-
485-
float sumy = 0;
486-
for (int i = 0; i < QK4_0 / 4; i++) {
487-
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
488-
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
489-
}
490-
491-
for (int row = 0; row < N_DST; row++) {
492-
// prefetch next x block
493-
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
494-
495-
// calculate
496-
const float d = qb_curr.d;
497-
const float m = qb_curr.m;
498-
float acc = 0.f;
499-
for (int i = 0; i < 16; i++) {
500-
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
501-
}
502-
sumf[row] += d * acc + m * sumy;
503-
qb_curr = qb_next;
504-
}
505-
}
506-
507-
if (nb % N_SIMDWIDTH == 0) {
508-
for (int row = 0; row < N_DST; ++row) {
509-
all_sum = simd_sum(sumf[row]);
510-
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
511-
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
512-
}
513-
}
514-
} else {
515-
516-
float sumy = 0;
517-
for (int i = 0; i < QK4_0 / 4; i++) {
518-
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
519-
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
520-
}
521-
522-
for (int row = 0; row < N_DST; row++) {
523-
// prefetch next x block
524-
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
525-
526-
// calculate
527-
const float d = qb_curr.d;
528-
const float m = qb_curr.m;
529-
float acc = 0.f;
530-
for (int i = 0; i < 16; i++) {
531-
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
532-
}
533-
if (tiisg < nb % N_SIMDWIDTH) {
534-
sumf[row] += d * acc + m * sumy;
535-
}
536-
qb_curr = qb_next;
475+
mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
476+
}
537477

538-
all_sum = simd_sum(sumf[row]);
539-
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
540-
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
541-
}
542-
}
543-
}
478+
kernel void kernel_mul_mat_q4_1_f32(
479+
device const void * src0,
480+
device const float * src1,
481+
device float * dst,
482+
constant int64_t & ne00,
483+
constant int64_t & ne10,
484+
constant int64_t & ne0,
485+
constant int64_t & ne01[[buffer(4)]],
486+
uint2 tgpig[[threadgroup_position_in_grid]],
487+
uint tiisg[[thread_index_in_simdgroup]],
488+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
489+
mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
544490
}
545491

546492
kernel void kernel_mul_mat_f16_f32(

0 commit comments

Comments
 (0)