@@ -365,6 +365,82 @@ kernel void kernel_rms_norm(
365
365
}
366
366
}
367
367
368
// Tuning constants for kernel_mul_mat_vec_q4_0_f32. They are file-scope macros
// on purpose: putting them in the kernel causes a significant performance penalty.
#define N_DST       4  // each SIMD group works on 4 rows
#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32

// Dot product of one q4_0 block with the QK4_0 src1 floats cached in `yl`.
//
// Byte i of `qs` holds two 4-bit quants: the low nibble pairs with yl[i] and
// the high nibble with yl[i + QK4_0/2]. Each quant q contributes d * (q - 8),
// so the sum is evaluated as d * (sum(y*q) - 8 * sum(y)).
inline float q4_0_block_dot_y(thread const block_q4_0 & qb, thread const float * yl) {
    float sum_yq = 0.0f;
    float sum_y  = 0.0f;
    for (int i = 0; i < QK4_0 / 2; i++) {
        sum_yq += yl[i] * (qb.qs[i] & 0xF) + yl[i + QK4_0 / 2] * (qb.qs[i] >> 4);
        sum_y  += yl[i] + yl[i + QK4_0 / 2];
    }
    const float d = qb.d;
    return d * (sum_yq - 8.f * sum_y);
}

// Matrix-vector product of q4_0-quantized rows (src0) with an f32 vector (src1).
//
// Work layout, as established by the indexing below:
//   - tgpig.x selects a group of N_SIMDGROUP * N_DST consecutive src0 rows;
//     within it, SIMD group `sgitg` owns N_DST consecutive rows.
//   - tgpig.y selects the src1 vector (stride ne10 floats) and the matching
//     dst column (stride ne0 floats).
//   - Each lane handles one q4_0 block per "column" of N_SIMDWIDTH blocks, and
//     the per-lane partial sums are combined with simd_sum.
//
// Parameters:
//   src0 - quantized matrix, read as block_q4_0 with ne00 / QK4_0 blocks per row
//   src1 - f32 input, one vector of ne10 floats per tgpig.y
//   dst  - f32 output, ne0 values per tgpig.y
//   ne00 - elements per src0 row
//          NOTE(review): assumed to be a multiple of QK4_0 - confirm with caller
//   ne10 - elements per src1 vector
//   ne0  - output values per src1 vector; used as the row bound for dst
//
// NOTE(review): the threadgroup is presumably dispatched with N_SIMDGROUP SIMD
// groups of N_SIMDWIDTH threads each - confirm against the host-side dispatch.
// NOTE(review): src1 is read through float4 loads, so each vector start is
// assumed to be 16-byte aligned - confirm against the host-side buffers.
kernel void kernel_mul_mat_vec_q4_0_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    const int nb = ne00/QK4_0;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;

    // first src0 row / dst element owned by this SIMD group
    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;

    // The whole SIMD group takes this branch together (the condition does not
    // depend on tiisg) and the kernel uses no threadgroup barrier, so an early
    // exit is safe. Without it, rows past ne0 would be stored into the next
    // dst column or past the end of dst.
    if (first_row >= ne0) {
        return;
    }

    device const block_q4_0 * x = (device const block_q4_0 *) src0 + first_row * nb;
    device const float      * y = (device const float      *) src1 + r1*ne10;

    block_q4_0 qb_curr, qb_next;
    float4 y_curr[QK4_0 / 4]; // src1 vector cache: the QK4_0 floats matching one block
    float  sumf[N_DST] = {0.f};
    thread float * yl = (thread float *) y_curr;

    const int  nb_full  = nb / N_SIMDWIDTH;       // columns in which every lane owns a block
    const int  nb_tail  = nb % N_SIMDWIDTH;       // lanes that own a block in the last, partial column
    const bool has_tail = (int) tiisg < nb_tail;  // does this lane own a tail block?

    // bootstrap: first block of row 0 (lanes past the end of a short row own none)
    if ((int) tiisg < nb) {
        qb_curr = x[tiisg];
    }

    // each thread in a SIMD group deals with 1 block.
    for (int column = 0; column < nb_full; column++) {
        // index, within a row, of the block this lane handles in this column
        const int ib = tiisg + column * N_SIMDWIDTH;

        device const float4 * y4 = (device const float4 *)(y + QK4_0 * ib);
        for (int i = 0; i < QK4_0 / 4; i++) {
            y_curr[i] = y4[i];
        }

        for (int row = 0; row < N_DST; row++) {
            // prefetch next x block: same column of the next row, or - after the
            // last row - the next column of row 0. On the very last iteration
            // this is the tail block consumed below.
            const int next_row = (row + 1) % N_DST;
            const int next_col = column + (row + 1) / N_DST;
            qb_next = x[tiisg + next_row * nb + next_col * N_SIMDWIDTH];

            // calculate
            sumf[row] += q4_0_block_dot_y(qb_curr, yl);
            qb_curr = qb_next;
        }
    }

    // tail column: only lanes below nb_tail own a block, the rest must not touch memory
    if (has_tail) {
        device const float4 * y4 = (device const float4 *)(y + QK4_0 * (tiisg + nb_full * N_SIMDWIDTH));
        for (int i = 0; i < QK4_0 / 4; i++) {
            y_curr[i] = y4[i];
        }
    }

    for (int row = 0; row < N_DST; row++) {
        if (has_tail) {
            // prefetch the tail block of the next row; nothing follows the last row
            const bool more_rows = row + 1 < N_DST;
            if (more_rows) {
                qb_next = x[tiisg + (row + 1) * nb + nb_full * N_SIMDWIDTH];
            }

            // calculate
            sumf[row] += q4_0_block_dot_y(qb_curr, yl);

            if (more_rows) {
                qb_curr = qb_next;
            }
        }

        // every lane of the SIMD group reaches the reduction, active in the tail or not
        const float all_sum = simd_sum(sumf[row]);

        // NOTE(review): when a group straddles ne0, the loads for its rows past
        // ne0 still happen above; only the store is suppressed here.
        if (tiisg == 0 && first_row + row < ne0) {
            dst[r1*ne0 + first_row + row] = all_sum;
        }
    }
}
443
+
368
444
kernel void kernel_mul_mat_q4_0_f32 (
369
445
device const void * src0,
370
446
device const float * src1,
0 commit comments