@@ -447,14 +447,13 @@ kernel void kernel_rms_norm(
447
447
constant int64_t & ne00,
448
448
constant uint64_t & nb01,
449
449
constant float & eps,
450
- threadgroup float * sum [[threadgroup(0 )]],
450
+ threadgroup float * buf [[threadgroup(0 )]],
451
451
uint tgpig[[threadgroup_position_in_grid]],
452
452
uint tpitg[[thread_position_in_threadgroup]],
453
453
uint sgitg[[simdgroup_index_in_threadgroup]],
454
454
uint tiisg[[thread_index_in_simdgroup]],
455
455
uint ntg[[threads_per_threadgroup]]) {
456
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
457
- device const float * x_scalar = (device const float *) x;
456
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
458
457
459
458
float4 sumf = 0 ;
460
459
float all_sum = 0 ;
@@ -465,40 +464,30 @@ kernel void kernel_rms_norm(
465
464
}
466
465
all_sum = sumf[0 ] + sumf[1 ] + sumf[2 ] + sumf[3 ];
467
466
all_sum = simd_sum (all_sum);
468
- if (tiisg == 0 ) {
469
- sum[sgitg] = all_sum;
470
- }
467
+ if (ntg > N_SIMDWIDTH) {
468
+ if (sgitg == 0 ) {
469
+ buf[tiisg] = 0 .0f ;
470
+ }
471
471
472
- threadgroup_barrier (mem_flags::mem_threadgroup);
472
+ threadgroup_barrier (mem_flags::mem_threadgroup);
473
473
474
- // broadcast, simd group number is ntg / 32
475
- for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
476
- if (tpitg < i) {
477
- sum[tpitg] += sum[tpitg + i];
478
- }
479
- }
480
- if (tpitg == 0 ) {
481
- for (int i = 4 * (ne00 / 4 ); i < ne00; i++) {
482
- sum[0 ] += x_scalar[i];
474
+ if (tiisg == 0 ) {
475
+ buf[sgitg] = all_sum;
483
476
}
484
- sum[0 ] /= ne00;
485
- }
486
477
487
- threadgroup_barrier (mem_flags::mem_threadgroup);
478
+ threadgroup_barrier (mem_flags::mem_threadgroup);
488
479
489
- const float mean = sum[0 ];
480
+ all_sum = buf[tiisg];
481
+ all_sum = simd_sum (all_sum);
482
+ }
483
+
484
+ const float mean = all_sum/ne00;
490
485
const float scale = 1 .0f /sqrt (mean + eps);
491
486
492
487
device float4 * y = (device float4 *) (dst + tgpig*ne00);
493
- device float * y_scalar = (device float *) y;
494
488
for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
495
489
y[i00] = x[i00] * scale;
496
490
}
497
- if (tpitg == 0 ) {
498
- for (int i00 = 4 * (ne00 / 4 ); i00 < ne00; i00++) {
499
- y_scalar[i00] = x_scalar[i00] * scale;
500
- }
501
- }
502
491
}
503
492
504
493
// function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
0 commit comments