@@ -11,6 +11,13 @@ typedef struct {
11
11
uint8_t qs[QK4_0 / 2 ]; // nibbles / quants
12
12
} block_q4_0;
13
13
14
// Number of quantized values packed into one Q4_1 block.
#define QK4_1 32

// One Q4_1 block: QK4_1 values stored as 4-bit quants, two per byte,
// sharing one scale and one offset. A quant q decodes to q*d + m.
typedef struct {
    half    d;              // scale applied to every quant in the block
    half    m;              // offset added after scaling
    uint8_t qs[QK4_1 / 2];  // packed quants, low and high nibble of each byte
} block_q4_1;
20
+
14
21
static void dequantize_row_q4_0 (device const block_q4_0 * x, device float * y, int k) {
15
22
const int qk = QK4_0;
16
23
@@ -31,6 +38,27 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
31
38
}
32
39
}
33
40
41
// Expand a row of Q4_1 blocks into floats.
//
//   x : source blocks; k / QK4_1 of them are read
//   y : destination; k floats are written
//   k : number of values in the row, must be a multiple of QK4_1
//
// Every 4-bit quant q decodes to q*d + m, where d and m are the block's
// scale and offset. Byte j of a block's qs holds the quant for position j
// in its low nibble and the quant for position j + QK4_1/2 in its high nibble.
static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
    const int qk = QK4_1;

    assert(k % qk == 0);

    const int nb = k / qk;

    for (int i = 0; i < nb; i++) {
        // Widen the scale and offset to float once per block. Keeping them as
        // half makes x0*d + m evaluate in half precision (rounded after the
        // multiply and again after the add) before being stored as float,
        // which disagrees with kernel_mul_mat_q4_1_f32 where both are cast
        // to float before use.
        const float d = (float) x[i].d;
        const float m = (float) x[i].m;

        device const uint8_t * qs  = x[i].qs;
        device float         * out = y + i*qk;

        for (int j = 0; j < qk/2; ++j) {
            const int x0 = (qs[j] & 0x0F);  // low nibble  -> first half of the block
            const int x1 = (qs[j] >>   4);  // high nibble -> second half of the block

            out[j       ] = x0*d + m;
            out[j + qk/2] = x1*d + m;
        }
    }
}
61
+
34
62
kernel void kernel_add (
35
63
device const float * src0,
36
64
device const float * src1,
@@ -212,6 +240,22 @@ kernel void kernel_get_rows_q4_0(
212
240
(device float *) ((device char *) dst + i*nb1), ne00);
213
241
}
214
242
243
// Gather rows of a Q4_1 tensor and write them out as floats.
//
// Thread i of the grid reads the row index src1[i], locates that row in src0
// using the byte stride nb01, and dequantizes ne00 values into the output row
// that starts i*nb1 bytes into dst.
kernel void kernel_get_rows_q4_1(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint                 tpig[[thread_position_in_grid]]) {
    const int out_row = tpig;

    // Row to fetch for this thread.
    device const int32_t * row_ids = (device const int32_t *) src1;
    const int src_row = row_ids[out_row];

    // Both strides are in bytes, so step through char pointers before casting.
    device const block_q4_1 * src_blocks =
        (device const block_q4_1 *) ((device const char *) src0 + src_row*nb01);
    device float * dst_row =
        (device float *) ((device char *) dst + out_row*nb1);

    dequantize_row_q4_1(src_blocks, dst_row, ne00);
}
258
+
215
259
kernel void kernel_rms_norm (
216
260
device const void * src0,
217
261
device float * dst,
@@ -350,6 +394,85 @@ kernel void kernel_mul_mat_q4_0_f32(
350
394
// }
351
395
}
352
396
397
// Dot product of one Q4_1 row of src0 with one float row of src1.
//
// Threadgroup (tgpig.x, tgpig.y) produces dst[tgpig.y*ne0 + tgpig.x]. The
// threads of the group split the row's blocks between them, each accumulating
// a partial sum that is then reduced through the threadgroup buffer `sum`.
//
// `sum` is indexed by the flattened thread id, so it needs one float per
// thread in the group.
// NOTE(review): the ix/iy split below and the 4/16-wide reduction appear to
// assume tptg.y == 8 and a thread count that is a multiple of 16 - confirm
// against the host-side dispatch.
kernel void kernel_mul_mat_q4_1_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2  tpig[[thread_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {
    // Number of Q4_1 blocks in one row of src0.
    const int blocks_per_row = ne00/QK4_1;

    const int64_t row0 = tgpig.x;  // row of src0
    const int64_t row1 = tgpig.y;  // row of src1

    device const block_q4_1 * xrow = (device const block_q4_1 *) src0 + row0*blocks_per_row;
    device const float      * yrow = (device const float      *) src1 + row1*ne10;

    const uint n_threads = tptg.x*tptg.y;
    const uint tid       = tptg.y*tpitg.x + tpitg.y;  // flattened thread index

    // tpitg.y selects which of two adjacent blocks (blk_sel) and which group
    // of 4 quant bytes inside that block (byte_grp) this thread handles.
    const int blk_sel  = tpitg.y/4;
    const int byte_grp = tpitg.y - 4*blk_sel;

    const int byte_off = 4 * byte_grp;

    float partial = 0;

    for (int ib = 2*tpitg.x + blk_sel; ib < blocks_per_row; ib += 2*tptg.x) {

        const float d = (float)xrow[ib].d;
        const float m = (float)xrow[ib].m;

        device const uint8_t * xq = xrow[ib].qs + byte_off;
        device const float   * yv = yrow + ib * QK4_1 + byte_off;

        // Low nibbles pair with yv[j], high nibbles with yv[j + QK4_1/2].
        float lo = 0.0f;
        float hi = 0.0f;

        for (int j = 0; j < 4; ++j) {
            const uint8_t q = xq[j];

            lo += yv[j          ] * (d * (q & 0xF) + m);
            hi += yv[j + QK4_1/2] * (d * (q >>  4) + m);
        }

        partial += lo + hi;
    }

    sum[tid] = partial;

    //
    // Reduce the per-thread partial sums held in the threadgroup buffer
    //

    // Stage 1: every 4th thread folds in the next three slots.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid%4 == 0) {
        for (int k = 1; k < 4; ++k) sum[tid] += sum[tid + k];
    }

    // Stage 2: every 16th thread folds in the stage-1 results at +4, +8, +12.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid%16 == 0) {
        for (int k = 4; k < 16; k += 4) sum[tid] += sum[tid + k];
    }

    // Stage 3: thread 0 folds in every 16th slot and stores the result.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        for (uint k = 16; k < n_threads; k += 16) sum[0] += sum[k];
        dst[row1*ne0 + row0] = sum[0];
    }
}
475
+
353
476
kernel void kernel_mul_mat_f16_f32 (
354
477
device const char * src0,
355
478
device const char * src1,
0 commit comments