@@ -8,14 +8,14 @@ using namespace metal;
#define QR4_0 2
typedef struct {
    half d;                 // delta
-   uint16_t qs[QK4_0 / 4]; // nibbles / quants
+   uint8_t  qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;

#define QK4_1 32
typedef struct {
    half d;                 // delta
    half m;                 // min
-   uint16_t qs[QK4_1 / 4]; // nibbles / quants
+   uint8_t  qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
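// Size note (illustrative, assuming QK4_0 == QK4_1 == 32): block_q4_0 is
// 18 bytes (one half plus 16 packed bytes) and block_q4_1 is 20 bytes, so a
// device uint16_t pointer into a block advanced by 1 or 2 elements
// respectively lands on the packed quants, which the dot-product helpers
// added further down rely on.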

static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
@@ -28,16 +28,12 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k)
    for (int i = 0; i < nb; i++) {
        const half d = x[i].d;

-       for (int j = 0; j < qk/4; ++j) {
-           const int x0 = (x[i].qs[j] & 0x000F) - 8;
-           const int x1 = ((x[i].qs[j] & 0x00F0) >>  4) - 8;
-           const int x2 = ((x[i].qs[j] & 0x0F00) >>  8) - 8;
-           const int x3 = ((x[i].qs[j] & 0xF000) >> 12) - 8;
+       for (int j = 0; j < qk/2; ++j) {
+           const int x0 = (x[i].qs[j] & 0x0F) - 8;
+           const int x1 = (x[i].qs[j] >>  4) - 8;

-           y[i*qk + 2*j + 0]        = x0*d;
-           y[i*qk + 2*j + qk/2]     = x1*d;
-           y[i*qk + 2*j + 1]        = x2*d;
-           y[i*qk + 2*j + 1 + qk/2] = x3*d;
+           y[i*qk + j + 0]    = x0*d;
+           y[i*qk + j + qk/2] = x1*d;
        }
    }
}
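
// A minimal packing sketch for the layout unpacked above (assuming
// QK4_0 == 32; pack_q4_0_nibbles is an illustrative helper, not part of
// ggml): element j of a block lives in the low nibble of qs[j] and element
// j + QK4_0/2 in the high nibble, which is why the unpacked values land at
// y[.. + j] and y[.. + j + qk/2].
static void pack_q4_0_nibbles(thread const uint8_t * v, thread uint8_t * qs) {
    for (int j = 0; j < QK4_0/2; ++j) {
        // v holds 32 already-quantized values in 0..15
        qs[j] = (v[j] & 0x0F) | ((v[j + QK4_0/2] & 0x0F) << 4);
    }
}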
@@ -53,16 +49,12 @@ static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k)
        const half d = x[i].d;
        const half m = x[i].m;

-       for (int j = 0; j < qk/4; ++j) {
-           const int x0 = (x[i].qs[j] & 0x000F);
-           const int x1 = ((x[i].qs[j] & 0x00F0) >>  4);
-           const int x2 = ((x[i].qs[j] & 0x0F00) >>  8);
-           const int x3 = ((x[i].qs[j] & 0xF000) >> 12);
+       for (int j = 0; j < qk/2; ++j) {
+           const int x0 = (x[i].qs[j] & 0x0F);
+           const int x1 = (x[i].qs[j] >> 4);

-           y[i*qk + 2*j + 0]        = x0*d + m;
-           y[i*qk + 2*j + qk/2]     = x1*d + m;
-           y[i*qk + 2*j + 1]        = x2*d + m;
-           y[i*qk + 2*j + 1 + qk/2] = x3*d + m;
+           y[i*qk + j + 0]    = x0*d + m;
+           y[i*qk + j + qk/2] = x1*d + m;
        }
    }
}
@@ -384,107 +376,92 @@ kernel void kernel_rms_norm(
    }
}

+ // function for calculating the inner product between a q4_0 block and 32 floats (yl); sumy is SUM(yl[i])
+ float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+     float d = qb_curr->d;
+     float4 acc = 0.f;
+     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1);
+     for (int i = 0; i < 16; i += 2) {
+         acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
+         acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
+         acc[2] += yl[i + 1]  * (qs[i / 2] & 0x0F00);
+         acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+     }
+     return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+ }
+
+ // function for calculating the inner product between a q4_1 block and 32 floats (yl); sumy is SUM(yl[i])
+ float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+     float d = qb_curr->d;
+     float m = qb_curr->m;
+     float4 acc = 0.f;
+     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2);
+     for (int i = 0; i < 16; i += 2) {
+         acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
+         acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
+         acc[2] += yl[i + 1]  * (qs[i / 2] & 0x0F00);
+         acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+     }
+     return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
+ }
+
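// Scalar sketch of the q4_0 dot product above, for reference (assuming
// QK4_0 == 32; block_q4_0_dot_y_ref is an illustrative helper, not part of
// ggml). Masking a nibble without shifting leaves it scaled by 16, 256, or
// 4096, which the divisions above undo once per block instead of once per
// element; the -8 zero-point is folded into sumy, since
// SUM((q - 8) * y) * d == d * (SUM(q * y) - 8 * sumy).
static float block_q4_0_dot_y_ref(device const block_q4_0 * qb, float sumy, thread float * yl) {
    float acc = 0.f;
    for (int i = 0; i < 16; i++) {
        acc += yl[i]      * (qb->qs[i] & 0x0F); // low nibble  -> element i
        acc += yl[i + 16] * (qb->qs[i] >>  4);  // high nibble -> element i + 16
    }
    return (float) qb->d * (acc - 8.f * sumy);
}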

// putting them in the kernel causes a significant performance penalty
#define N_DST 4        // each SIMD group works on 4 rows
#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
- kernel void kernel_mul_mat_q4_0_f32(
-         device const void * src0,
-         device const float * src1,
-         device float * dst,
-         constant int64_t & ne00,
-         constant int64_t & ne10,
-         constant int64_t & ne0,
-         constant int64_t & ne01[[buffer(4)]],
-         uint2 tgpig[[threadgroup_position_in_grid]],
-         uint tiisg[[thread_index_in_simdgroup]],
-         uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ template <typename block_q_type>
+ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
+                      int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
+                      uint2 tgpig, uint tiisg, uint sgitg) {
    const int nb = ne00/QK4_0;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
-   device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+   device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
    device const float * y = (device const float *) src1 + r1*ne10;
-   block_q4_0 qb_curr, qb_next;
    float4 y_curr[8]; // src1 vector cache
    float sumf[N_DST]={0.f}, all_sum;
    thread float * yl = (thread float *)y_curr;

-   // bootstrap
-   qb_curr = x[tiisg];
    // each thread in a SIMD group deals with 1 block.
    for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-
        float sumy = 0;
        for (int i = 0; i < QK4_0 / 4; i++) {
-           y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
+           y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
        }
-       sumy *= (-8.f);
-       // we don't right shift packed 4-bit weights, so we have to divide y by 16/256/4096 to compensate
-       for (int i = 0; i < 32; i++) {
-           yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
-       }

        for (int row = 0; row < N_DST; row++) {
-           // prefetch next x block
-           qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-           // calculate
-           float d = qb_curr.d;
-           float acc = sumy;
-           for (int i = 0; i < 16; i += 2) {
-               acc += yl[i]     * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
-               acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
-           }
-           sumf[row] += d * acc;
-           qb_curr = qb_next;
+           sumf[row] += block_q_n_dot_y(x + (tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
        }
    }

-   if (nb % N_SIMDWIDTH == 0) {
-       for (int row = 0; row < N_DST; ++row) {
-           all_sum = simd_sum(sumf[row]);
-           if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-               dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-           }
-       }
-   } else {
-
+   // from here on the SIMD group handles two rows at a time, 16 blocks per row
+   int ir = tiisg / (N_SIMDWIDTH / 2);
+   int ib = tiisg % (N_SIMDWIDTH / 2);
+   for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
+       int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); // where the leftover blocks start
        float sumy = 0;
        for (int i = 0; i < QK4_0 / 4; i++) {
-           y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
+           y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
        }
-       sumy *= (-8.f);
-       for (int i = 0; i < 32; i++) {
-           yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
-       }

-       for (int row = 0; row < N_DST; row++) {
-           // prefetch next x block
-           qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-           // calculate
-           float d = qb_curr.d;
-           float acc = sumy;
-           for (int i = 0; i < 16; i += 2) {
-               acc += yl[i]     * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
-               acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
-           }
-           if (tiisg < nb % N_SIMDWIDTH) {
-               sumf[row] += d * acc;
+       for (int row = 0; row < N_DST; row += 2) {
+           if (nb_start + ib < nb) {
+               sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
            }
-           qb_curr = qb_next;
+       }
+   }

-           all_sum = simd_sum(sumf[row]);
-           if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-               dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-           }
+   for (int row = 0; row < N_DST; ++row) {
+       all_sum = simd_sum(sumf[row]);
+       if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
+           dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
        }
    }
}
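
// Worked example for the leftover-block loop above (assuming N_SIMDWIDTH == 32
// and N_DST == 4): with ne00 == 4128, nb == 129, so the main loop covers
// blocks 0..127 and one block remains. The tail loop runs once (ind == 0,
// nb_start == 128); only the threads with ib == 0 pass the bounds check, the
// one with ir == 0 accumulates rows 0 and 2, and the one with ir == 1 rows 1
// and 3, so all N_DST rows are covered.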

- kernel void kernel_mul_mat_q4_1_f32(
+ kernel void kernel_mul_mat_q4_0_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,
@@ -495,89 +472,21 @@ kernel void kernel_mul_mat_q4_1_f32(
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
-   const int nb = ne00/QK4_0;
-   const int r0 = tgpig.x;
-   const int r1 = tgpig.y;
-   device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
-   device const float * y = (device const float *) src1 + r1*ne10;
-   block_q4_1 qb_curr, qb_next;
-   float4 y_curr[8]; // src1 vector cache
-   float sumf[N_DST]={0.f}, all_sum;
-   thread float * yl = (thread float *)y_curr;
-
-   // bootstrap
-   qb_curr = x[tiisg];
-   // each thread in a SIMD group deals with 1 block.
-   for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-
-       float sumy = 0;
-       for (int i = 0; i < QK4_0 / 4; i++) {
-           y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
-           sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-       }
-       // we don't right shift packed 4-bit weights, so we have to divide y by 16/256/4096 to compensate
-       for (int i = 0; i < 32; i++) {
-           yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
-       }
-
-       for (int row = 0; row < N_DST; row++) {
-           // prefetch next x block
-           qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-           // calculate
-           const float d = qb_curr.d;
-           const float m = qb_curr.m;
-           float acc = 0.f;
-           for (int i = 0; i < 16; i += 2) {
-               acc += yl[i]     * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
-               acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
-           }
-           sumf[row] += d * acc + m * sumy;
-           qb_curr = qb_next;
-       }
-   }
-
-   if (nb % N_SIMDWIDTH == 0) {
-       for (int row = 0; row < N_DST; ++row) {
-           all_sum = simd_sum(sumf[row]);
-           if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-               dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-           }
-       }
-   } else {
-
-       float sumy = 0;
-       for (int i = 0; i < QK4_0 / 4; i++) {
-           y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
-           sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-       }
-       for (int i = 0; i < 32; i++) {
-           yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
-       }
-
-       for (int row = 0; row < N_DST; row++) {
-           // prefetch next x block
-           qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-           // calculate
-           const float d = qb_curr.d;
-           const float m = qb_curr.m;
-           float acc = 0.f;
-           for (int i = 0; i < 16; i += 2) {
-               acc += yl[i]     * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
-               acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
-           }
-           if (tiisg < nb % N_SIMDWIDTH) {
-               sumf[row] += d * acc + m * sumy;
-           }
-           qb_curr = qb_next;
+   mul_vec_q_n_f32<block_q4_0>(src0, src1, dst, ne00, ne10, ne0, ne01, tgpig, tiisg, sgitg);
+ }

-           all_sum = simd_sum(sumf[row]);
-           if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-               dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-           }
-       }
-   }
+ kernel void kernel_mul_mat_q4_1_f32(
+     device const void * src0,
+     device const float * src1,
+     device float * dst,
+     constant int64_t & ne00,
+     constant int64_t & ne10,
+     constant int64_t & ne0,
+     constant int64_t & ne01[[buffer(4)]],
+     uint2 tgpig[[threadgroup_position_in_grid]],
+     uint tiisg[[thread_index_in_simdgroup]],
+     uint sgitg[[simdgroup_index_in_threadgroup]]) {
+   mul_vec_q_n_f32<block_q4_1>(src0, src1, dst, ne00, ne10, ne0, ne01, tgpig, tiisg, sgitg);
}
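
// Grid-mapping note (inferred from the indexing above, not stated in this
// change): each threadgroup covers N_SIMDGROUP * N_DST == 8 output rows, and
// row (r0 * N_SIMDGROUP + sgitg) * N_DST + row is bounds-checked against ne01,
// so the host side is expected to dispatch ceil(ne01 / 8) threadgroups along x
// with N_SIMDGROUP SIMD groups per threadgroup.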

kernel void kernel_mul_mat_f16_f32(