@@ -331,26 +331,33 @@ kernel void kernel_rms_norm(
331
331
threadgroup float * sum [[threadgroup(0 )]],
332
332
uint tgpig[[threadgroup_position_in_grid]],
333
333
uint tpitg[[thread_position_in_threadgroup]],
334
+ uint sgitg[[simdgroup_index_in_threadgroup]],
335
+ uint tiisg[[thread_index_in_simdgroup]],
334
336
uint ntg[[threads_per_threadgroup]]) {
335
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
337
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
338
+ device const float * x_scalar = (device const float *) x;
339
+ float4 sumf=0 ;
340
+ float all_sum=0 ;
336
341
337
342
// parallel sum
338
- sum[tpitg] = 0 .0f ;
339
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
340
- sum[tpitg] += x[i00] * x[i00];
343
+ for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
344
+ sumf += x[i00] * x[i00];
345
+ }
346
+ all_sum = sumf[0 ] + sumf[1 ] + sumf[2 ] + sumf[3 ];
347
+ all_sum = simd_sum (all_sum);
348
+ if (tiisg == 0 ) {
349
+ sum[sgitg] = all_sum;
341
350
}
342
351
343
- // reduce
344
352
threadgroup_barrier (mem_flags::mem_threadgroup);
345
- for ( uint i = ntg/ 2 ; i > 0 ; i /= 2 ) {
346
- if (tpitg < i ) {
347
- sum[ tpitg] += sum[tpitg + i];
348
- }
349
- threadgroup_barrier (mem_flags::mem_threadgroup);
353
+ // broadcast, simd group number is ntg / 32
354
+ for ( int i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
355
+ if ( tpitg < i) {
356
+ sum[tpitg] += sum[tpitg + i];
357
+ }
350
358
}
351
-
352
- // broadcast
353
359
if (tpitg == 0 ) {
360
+ for (int i = 4 * (ne00 / 4 ); i < ne00; i++) {sum[0 ] += x_scalar[i];}
354
361
sum[0 ] /= ne00;
355
362
}
356
363
@@ -359,104 +366,102 @@ kernel void kernel_rms_norm(
359
366
const float mean = sum[0 ];
360
367
const float scale = 1 .0f /sqrt (mean + eps);
361
368
362
- device float * y = dst + tgpig*ne00;
363
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
369
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
370
+ device float * y_scalar = (device float *) y;
371
+ for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
364
372
y[i00] = x[i00] * scale;
365
373
}
374
+ if (tpitg == 0 ) {
375
+ for (int i00 = 4 * (ne00 / 4 ); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
376
+ }
377
+ }
378
+
379
// Inner product of one q4_0 block with 32 floats (yl); sumy must equal SUM(yl[i]).
// The four nibble lanes are accumulated at their in-word bit positions and
// rescaled once at the end, avoiding a shift per element.
float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
    const float d = qb_curr->d;
    // view the quants as 16-bit words; +1 skips the leading 16-bit scale d
    device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
    float4 partial = 0.f;
    for (int j = 0; j < 8; j++) {
        const uint16_t b = qs[j];
        partial[0] += yl[2*j +  0] * (b & 0x000F);
        partial[1] += yl[2*j + 16] * (b & 0x00F0);
        partial[2] += yl[2*j +  1] * (b & 0x0F00);
        partial[3] += yl[2*j + 17] * (b & 0xF000);
    }
    // q4_0 nibbles carry an implicit -8 offset: result = d * SUM((q - 8) * y)
    return d * (sumy * -8.f + partial[0] + partial[1]/16.f + partial[2]/256.f + partial[3]/4096.f);
}
392
+
393
// Inner product of one q4_1 block with 32 floats (yl); sumy must equal SUM(yl[i]).
// Same lane-position trick as the q4_0 overload, but q4_1 has no -8 offset and
// instead adds the per-block minimum m scaled by the sum of y.
float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
    const float d = qb_curr->d;
    const float m = qb_curr->m;
    // view the quants as 16-bit words; +2 skips the two leading 16-bit fields d and m
    device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
    float4 partial = 0.f;
    for (int j = 0; j < 8; j++) {
        const uint16_t b = qs[j];
        partial[0] += yl[2*j +  0] * (b & 0x000F);
        partial[1] += yl[2*j + 16] * (b & 0x00F0);
        partial[2] += yl[2*j +  1] * (b & 0x0F00);
        partial[3] += yl[2*j + 17] * (b & 0xF000);
    }
    // result = d * SUM(q * y) + m * SUM(y)
    return d * (partial[0] + partial[1]/16.f + partial[2]/256.f + partial[3]/4096.f) + sumy * m;
}
367
407
368
408
// Kept as file-scope macros: defining these inside the kernel causes a
// significant performance penalty.
#define N_DST        4  // each SIMD group works on 4 rows
#define N_SIMDGROUP  2  // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32  // assuming SIMD group size is 32
372
// Shared matrix-vector body for the q4_0 / q4_1 kernels. Each SIMD group
// produces N_DST rows of dst; each lane of the group owns one quant block per
// column iteration. Lane-partial sums are combined with simd_sum at the end.
// Assumes a SIMD group width of N_SIMDWIDTH (32).
template<typename block_q_type>
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
                     int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
                     uint2 tgpig, uint tiisg, uint sgitg) {
    const int nb = ne00/QK4_0;          // quant blocks per row
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    // first src0 row handled by this SIMD group
    device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
    device const float * y = (device const float *) src1 + r1*ne10;

    float4 y_curr[8];                   // src1 vector cache (one QK4_0 slice)
    float sumf[N_DST] = {0.f}, all_sum;
    thread float * yl = (thread float *)y_curr;

    // main loop: every lane handles one full block per column step
    for (int col = 0; col < nb / N_SIMDWIDTH; col++) {
        float sumy = 0;
        for (int i = 0; i < QK4_0 / 4; i++) {
            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + col * QK4_0)) + i);
            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
        }

        for (int row = 0; row < N_DST; row++) {
            sumf[row] += block_q_n_dot_y(x + (tiisg + row * nb + col * N_SIMDWIDTH), sumy, yl);
        }
    }

    // tail: fewer than N_SIMDWIDTH blocks remain, so each iteration the group
    // loads two rows at a time, 16 blocks per row
    int ir = tiisg / (N_SIMDWIDTH / 2); // which of the two rows this lane takes
    int ib = tiisg % (N_SIMDWIDTH / 2); // block index within the half-width
    for (int it = 0; it < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); it++) {
        int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + it * (N_SIMDWIDTH / 2); // where the leftover blocks start

        float sumy = 0;
        for (int i = 0; i < QK4_0 / 4; i++) {
            y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
        }

        for (int row = 0; row < N_DST; row += 2) {
            if (nb_start + ib < nb) {   // guard against running past the row end
                sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
            }
        }
    }

    // reduce each row across the SIMD group; lane 0 writes the result,
    // guarding against rows past ne01
    for (int row = 0; row < N_DST; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
            dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
        }
    }
}
458
463
459
// q4_0 x f32 matrix-vector multiply: thin entry point that forwards all
// arguments to the shared templated body.
kernel void kernel_mul_mat_q4_0_f32(
        device const void  * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        constant   int64_t & ne01[[buffer(4)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32<block_q4_0>(src0, src1, dst, ne00, ne10, ne0, ne01, tgpig, tiisg, sgitg);
}
537
// q4_1 x f32 matrix-vector multiply: thin entry point that forwards all
// arguments to the shared templated body.
kernel void kernel_mul_mat_q4_1_f32(
        device const void  * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        constant   int64_t & ne01[[buffer(4)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32<block_q4_1>(src0, src1, dst, ne00, ne10, ne0, ne01, tgpig, tiisg, sgitg);
}
545
491
546
492
kernel void kernel_mul_mat_f16_f32 (
0 commit comments