Skip to content

Commit e9b66ee

Browse files
ikawrakow (Kawrakow) and a co-author authored
metal : add Q4_1 implementation (#1785)
23.3 ms / token, so just ~1% slower than q4_0. Achieves 290 GB/s memory throughput. Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 4f0154b commit e9b66ee

File tree

2 files changed

+138
-1
lines changed

2 files changed

+138
-1
lines changed

ggml-metal.m

Lines changed: 15 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -50,12 +50,14 @@
5050
GGML_METAL_DECL_KERNEL(diag_mask_inf);
5151
GGML_METAL_DECL_KERNEL(get_rows_f16);
5252
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
53+
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
5354
GGML_METAL_DECL_KERNEL(get_rows_q2_k);
5455
GGML_METAL_DECL_KERNEL(get_rows_q4_k);
5556
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
5657
GGML_METAL_DECL_KERNEL(rms_norm);
5758
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
5859
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
60+
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
5961
GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
6062
GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
6163
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
@@ -141,12 +143,14 @@
141143
GGML_METAL_ADD_KERNEL(diag_mask_inf);
142144
GGML_METAL_ADD_KERNEL(get_rows_f16);
143145
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
146+
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
144147
GGML_METAL_ADD_KERNEL(get_rows_q2_k);
145148
GGML_METAL_ADD_KERNEL(get_rows_q4_k);
146149
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
147150
GGML_METAL_ADD_KERNEL(rms_norm);
148151
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
149152
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
153+
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
150154
GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
151155
GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
152156
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
@@ -545,6 +549,15 @@ void ggml_metal_graph_compute(
545549
nth1 = 8;
546550
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
547551
} break;
552+
case GGML_TYPE_Q4_1:
553+
{
554+
GGML_ASSERT(ne02 == 1);
555+
GGML_ASSERT(ne12 == 1);
556+
557+
nth0 = 8;
558+
nth1 = 8;
559+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
560+
} break;
548561
case GGML_TYPE_Q2_K:
549562
{
550563
GGML_ASSERT(ne02 == 1);
@@ -596,7 +609,7 @@ void ggml_metal_graph_compute(
596609
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
597610
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
598611

599-
if (src0t == GGML_TYPE_Q4_0) {
612+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
600613
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
601614
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
602615
} else if (src0t == GGML_TYPE_Q2_K) {
@@ -623,6 +636,7 @@ void ggml_metal_graph_compute(
623636
switch (src0->type) {
624637
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
625638
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
639+
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
626640
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
627641
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
628642
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;

ggml-metal.metal

Lines changed: 123 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -11,6 +11,13 @@ typedef struct {
1111
uint8_t qs[QK4_0 / 2]; // nibbles / quants
1212
} block_q4_0;
1313

14+
#define QK4_1 32
15+
typedef struct {
16+
half d; // delta
17+
half m; // min
18+
uint8_t qs[QK4_1 / 2]; // nibbles / quants
19+
} block_q4_1;
20+
1421
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
1522
const int qk = QK4_0;
1623

@@ -31,6 +38,27 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
3138
}
3239
}
3340

41+
static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
42+
const int qk = QK4_1;
43+
44+
assert(k % qk == 0);
45+
46+
const int nb = k / qk;
47+
48+
for (int i = 0; i < nb; i++) {
49+
const half d = x[i].d;
50+
const half m = x[i].m;
51+
52+
for (int j = 0; j < qk/2; ++j) {
53+
const int x0 = (x[i].qs[j] & 0x0F);
54+
const int x1 = (x[i].qs[j] >> 4);
55+
56+
y[i*qk + j + 0 ] = x0*d + m;
57+
y[i*qk + j + qk/2] = x1*d + m;
58+
}
59+
}
60+
}
61+
3462
kernel void kernel_add(
3563
device const float * src0,
3664
device const float * src1,
@@ -212,6 +240,22 @@ kernel void kernel_get_rows_q4_0(
212240
(device float *) ((device char *) dst + i*nb1), ne00);
213241
}
214242

243+
kernel void kernel_get_rows_q4_1(
244+
device const void * src0,
245+
device const int * src1,
246+
device float * dst,
247+
constant int64_t & ne00,
248+
constant uint64_t & nb01,
249+
constant uint64_t & nb1,
250+
uint tpig[[thread_position_in_grid]]) {
251+
const int i = tpig;
252+
const int r = ((device int32_t *) src1)[i];
253+
254+
dequantize_row_q4_1(
255+
(device const block_q4_1 *) ((device char *) src0 + r*nb01),
256+
(device float *) ((device char *) dst + i*nb1), ne00);
257+
}
258+
215259
kernel void kernel_rms_norm(
216260
device const void * src0,
217261
device float * dst,
@@ -350,6 +394,85 @@ kernel void kernel_mul_mat_q4_0_f32(
350394
//}
351395
}
352396

397+
kernel void kernel_mul_mat_q4_1_f32(
398+
device const void * src0,
399+
device const float * src1,
400+
device float * dst,
401+
constant int64_t & ne00,
402+
constant int64_t & ne01,
403+
constant uint64_t & nb00,
404+
constant uint64_t & nb01,
405+
constant uint64_t & nb02,
406+
constant int64_t & ne10,
407+
constant int64_t & ne11,
408+
constant uint64_t & nb10,
409+
constant uint64_t & nb11,
410+
constant uint64_t & nb12,
411+
constant int64_t & ne0,
412+
constant int64_t & ne1,
413+
threadgroup float * sum [[threadgroup(0)]],
414+
uint2 tgpig[[threadgroup_position_in_grid]],
415+
uint2 tpig[[thread_position_in_grid]],
416+
uint2 tpitg[[thread_position_in_threadgroup]],
417+
uint2 tptg[[threads_per_threadgroup]]) {
418+
const int nb = ne00/QK4_1;
419+
420+
const int64_t r0 = tgpig.x;
421+
const int64_t r1 = tgpig.y;
422+
423+
device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
424+
device const float * y = (device const float *) src1 + r1*ne10;
425+
426+
const uint nth = tptg.x*tptg.y;
427+
const uint ith = tptg.y*tpitg.x + tpitg.y;
428+
429+
const int ix = tpitg.y/4; // 0 or 1
430+
const int iy = tpitg.y - 4*ix; // 0...3
431+
432+
const int first = 4 * iy;
433+
434+
float sumf = 0;
435+
436+
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
437+
438+
const float d = (float)x[i].d;
439+
const float m = (float)x[i].m;
440+
441+
device const uint8_t * xl = x[i].qs + first;
442+
device const float * yl = y + i * QK4_1 + first;
443+
444+
float2 acc = {0.0f, 0.0f};
445+
446+
for (int j = 0; j < 4; ++j) {
447+
448+
acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
449+
acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
450+
451+
}
452+
453+
sumf += acc[0] + acc[1];
454+
}
455+
456+
sum[ith] = sumf;
457+
458+
//
459+
// Accumulate the sum from all threads in the threadgroup
460+
//
461+
threadgroup_barrier(mem_flags::mem_threadgroup);
462+
if (ith%4 == 0) {
463+
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
464+
}
465+
threadgroup_barrier(mem_flags::mem_threadgroup);
466+
if (ith%16 == 0) {
467+
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
468+
}
469+
threadgroup_barrier(mem_flags::mem_threadgroup);
470+
if (ith == 0) {
471+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
472+
dst[r1*ne0 + r0] = sum[0];
473+
}
474+
}
475+
353476
kernel void kernel_mul_mat_f16_f32(
354477
device const char * src0,
355478
device const char * src1,

0 commit comments

Comments (0)