@@ -387,87 +387,90 @@ kernel void kernel_rms_norm(
 }
 }
 
-// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+// function for calculating the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
-    float4 acc = 0.f;
-    device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
-    for (int i = 0; i < 16; i+=2) {
-        acc[0] += yl[i] * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
-        acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    float2 acc = 0.f;
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+    return d * (sumy * -8.f + acc[0] + acc[1]);
 }
 
-// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+// function for calculating the inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float m = qb_curr->m;
-    float4 acc = 0.f;
-    device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
-    for (int i = 0; i < 16; i+=2) {
-        acc[0] += yl[i] * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
-        acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+    float2 acc = 0.f;
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
+    return d * (acc[0] + acc[1]) + sumy * m;
 }
 
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-template <typename block_q_type>
+// Note: This is a template, but strictly speaking it only applies to
+//       quantizations where the block size is 32. It also does not
+//       guard against the number of rows not being divisible by
+//       N_DST, so this is another explicit assumption of the implementation.
+template <typename block_q_type, int nr, int nsg, int nw>
 void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
                      int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
                      uint2 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
     device const float * y = (device const float *) src1 + r1*ne10;
-    float4 y_curr[8];     // src1 vector cache
-    float sumf[N_DST]={0.f}, all_sum;
-    thread float * yl=(thread float *)y_curr;
+    float yl[16];     // src1 vector cache
+    float sumf[nr]={0.f};
 
-    // each thread in a SIMD group deals with 1 block.
-    for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-        }
+    const int ix = tiisg/2;
+    const int il = 8*(tiisg%2);
 
-        for (int row = 0; row < N_DST; row++) {
-            sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
-        }
-    }
+    device const float * yb = y + ix * QK4_0 + il;
 
-    // from now loads two rows every time and 16 blocks per row
-    int ir = tiisg / (N_SIMDWIDTH / 2);
-    int ib = tiisg % (N_SIMDWIDTH / 2);
-    for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
-        int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); // where the left blocks start
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
         float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
+        for (int i = 0; i < 8; i += 2) {
+            sumy += yb[i] + yb[i+1];
+            yl[i+0] = yb[i+ 0];
+            yl[i+1] = yb[i+ 1]/256.f;
+            sumy += yb[i+16] + yb[i+17];
+            yl[i+8] = yb[i+16]/16.f;
+            yl[i+9] = yb[i+17]/4096.f;
         }
 
-        for (int row = 0; row < N_DST; row+=2) {
-            if (nb_start + ib < nb) {
-                sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
-            }
+        for (int row = 0; row < nr; row++) {
+            sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
         }
+
+        yb += QK4_0 * 16;
     }
 
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-            dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + first_row + row] = tot;
         }
     }
 }
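
For reference, the rewritten block_q_n_dot_y relies on the caller pre-scaling the cached yl values by 1, 1/16, 1/256 and 1/4096, so that masking a nibble out of a uint16_t (rather than shifting it down) still contributes y * q, while the -8 quantization offset is folded into the sumy * -8.f term. The net result is the same as a plain per-element q4_0 dot product. A minimal C sketch of that reference computation follows; the block_q4_0_ref struct and the helper name q4_0_dot_ref are illustrative only (ggml stores the scale as a half and packs two quants per byte, low nibble first).

#include <stdint.h>

#define QK4_0 32

// illustrative stand-in for ggml's block_q4_0: one scale plus 16 packed nibbles,
// where byte i holds quant i in its low nibble and quant i+16 in its high nibble
typedef struct {
    float   d;              // scale (a half-precision float in ggml)
    uint8_t qs[QK4_0 / 2];  // packed 4-bit quants
} block_q4_0_ref;

// reference dot product of one q4_0 block with 32 floats: sum_i d * (q_i - 8) * y_i
static float q4_0_dot_ref(const block_q4_0_ref * b, const float * y) {
    float acc = 0.0f;
    for (int i = 0; i < QK4_0 / 2; ++i) {
        const int q_lo = (b->qs[i] & 0x0F) - 8; // quant of element i
        const int q_hi = (b->qs[i] >>   4) - 8; // quant of element i + 16
        acc += y[i] * q_lo + y[i + QK4_0 / 2] * q_hi;
    }
    return b->d * acc;
}
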
@@ -483,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mat_q4_1_f32(
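
To make the work split in mul_vec_q_n_f32 above concrete: within a SIMD group, lanes are paired per q4_0 block (ix = tiisg/2 picks the starting block, il = 8*(tiisg%2) picks which half of it), and each pair then strides through the row nw/2 = 16 blocks at a time. The standalone C sketch below simply prints that mapping for the first few lanes; it illustrates the indexing and is not code from the commit.

#include <stdio.h>

#define N_SIMDWIDTH 32 // the kernel's nw template argument

int main(void) {
    const int nb = 64; // example: 64 blocks per row, i.e. ne00 = 2048

    for (int tiisg = 0; tiisg < 4; ++tiisg) {   // first few lanes of the SIMD group
        const int ix = tiisg / 2;               // starting block index for this lane
        const int il = 8 * (tiisg % 2);         // offset of the half-block it handles

        printf("lane %d: blocks", tiisg);
        for (int ib = ix; ib < nb; ib += N_SIMDWIDTH / 2) {
            printf(" %d", ib);
        }
        printf(", y elements %d..%d and %d..%d of each block\n",
               il, il + 7, il + 16, il + 23);
    }
    return 0;
}
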
@@ -497,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mat_f16_f32(