Skip to content

Commit 38ec9a2

Browse files
committed
metal: New q4_0 matrix-vector kernel
Prefetch data to improve GPU utilization. ~48% faster for 33B model.
1 parent 20d7740 commit 38ec9a2

File tree

2 files changed

+89
-3
lines changed

2 files changed

+89
-3
lines changed

ggml-metal.m

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
GGML_METAL_DECL_KERNEL(norm);
6363
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
6464
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
65+
GGML_METAL_DECL_KERNEL(mul_mat_vec_q4_0_f32);
6566
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
6667
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
6768
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
@@ -177,6 +178,7 @@ @implementation GGMLMetalClass
177178
GGML_METAL_ADD_KERNEL(norm);
178179
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
179180
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
181+
GGML_METAL_ADD_KERNEL(mul_mat_vec_q4_0_f32);
180182
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
181183
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
182184
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
@@ -660,7 +662,11 @@ void ggml_metal_graph_compute(
660662

661663
nth0 = 8;
662664
nth1 = 8;
663-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
665+
if (ne01 % 8 == 0) {
666+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_vec_q4_0_f32];
667+
} else {
668+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
669+
}
664670
} break;
665671
case GGML_TYPE_Q4_1:
666672
{
@@ -740,8 +746,12 @@ void ggml_metal_graph_compute(
740746
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
741747

742748
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
743-
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
744-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
749+
if (ne01 % 8 == 0) {
750+
[encoder dispatchThreadgroups:MTLSizeMake(ne01/8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
751+
} else {
752+
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
753+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
754+
}
745755
}
746756
else if (src0t == GGML_TYPE_Q2_K ||
747757
src0t == GGML_TYPE_Q3_K ||

ggml-metal.metal

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,82 @@ kernel void kernel_rms_norm(
365365
}
366366
}
367367

368+
// These must stay compile-time #defines:
// putting them in the kernel causes a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32

// Matrix-vector product: dst[r1,:] += src0 (q4_0-quantized, ne01 x ne00) * src1[r1,:] (f32).
// One threadgroup covers N_SIMDGROUP * N_DST = 8 rows, so the dispatcher uses this
// kernel only when ne01 % 8 == 0 and launches ne01/8 threadgroups along x.
// Each SIMD lane owns one q4_0 block per column step; the x-block for the NEXT
// (row, column) step is prefetched while the current one is being reduced, which
// is what hides the global-memory latency this kernel was written for.
kernel void kernel_mul_mat_vec_q4_0_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    const int nb = ne00/QK4_0;          // q4_0 blocks per matrix row
    const int r0 = tgpig.x;             // which group of 8 rows
    const int r1 = tgpig.y;             // which src1 column / dst row

    // first of the N_DST consecutive rows handled by this SIMD group
    device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
    device const float      * y = (device const float      *) src1 + r1*ne10;

    block_q4_0 qb_curr, qb_next;
    float4 y_curr[8];                   // src1 cache: QK4_0 floats for this lane's block
    float sumf[N_DST]={0.f}, all_sum;
    thread float * yl = (thread float *)y_curr;

    // bootstrap the prefetch pipeline; guard so a short row (nb < N_SIMDWIDTH)
    // does not read past the end of the row (qb_curr is unused in that case)
    if (tiisg < nb) {
        qb_curr = x[tiisg];
    }

    // full columns: each lane handles block (column*N_SIMDWIDTH + tiisg)
    for (int column = 0; column < nb / N_SIMDWIDTH; column++) {

        for (int i = 0; i < QK4_0 / 4; i++) {
            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
        }

        for (int row = 0; row < N_DST; row++) {
            // prefetch the x block for the next step: next row at this column,
            // or row 0 of the next column after the last row.
            // NOTE: stays within this group's N_DST rows, so no out-of-bounds.
            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];

            // q4_0 dot product: nibbles are biased by 8, so
            // sum(q_i * y_i) = d * (sum(nib_i * y_i) - 8 * sum(y_i))
            float d = qb_curr.d;
            float2 acc = {0.0f, 0.0f};
            for (int i = 0; i < 16; i++) {
                acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
                acc[1] += yl[i] + yl[i+16];
            }
            sumf[row] += d * (acc[0] - 8.f*acc[1]);
            qb_curr = qb_next;
        }
    }

    // remainder: the last (nb % N_SIMDWIDTH) blocks of each row.
    // Only lanes tiisg < rem have work here; the loads are guarded as well,
    // since unguarded lanes would read past the end of the row (and, for the
    // last threadgroup / last r1, past the end of the src0/src1 buffers).
    const int rem = nb % N_SIMDWIDTH;

    if (tiisg < rem) {
        for (int i = 0; i < QK4_0 / 4; i++) {
            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
        }
    }

    for (int row = 0; row < N_DST; row++) {
        if (tiisg < rem) {
            // direct load of this row's tail block; equivalent to the prefetch
            // chain (for row 0 it is exactly the qb_curr carried out of the
            // main loop) but never reads beyond the last row of the group
            const block_q4_0 qb = x[tiisg + row * nb + (nb / N_SIMDWIDTH) * N_SIMDWIDTH];

            float d = qb.d;
            float2 acc = {0.0f, 0.0f};
            for (int i = 0; i < 16; i++) {
                acc[0] += yl[i] * (qb.qs[i] & 0xF) + yl[i+16] * (qb.qs[i] >> 4);
                acc[1] += yl[i] + yl[i+16];
            }
            sumf[row] += d * (acc[0] - 8.f*acc[1]);
        }

        // the reduction and the store must run for every lane/row, even when
        // rem == 0, because this is the only place the result is written out
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
        }
    }
}
443+
368444
kernel void kernel_mul_mat_q4_0_f32(
369445
device const void * src0,
370446
device const float * src1,

0 commit comments

Comments
 (0)