@@ -339,26 +339,33 @@ kernel void kernel_rms_norm(
339
339
threadgroup float * sum [[threadgroup(0 )]],
340
340
uint tgpig[[threadgroup_position_in_grid]],
341
341
uint tpitg[[thread_position_in_threadgroup]],
342
+ uint sgitg[[simdgroup_index_in_threadgroup]],
343
+ uint tiisg[[thread_index_in_simdgroup]],
342
344
uint ntg[[threads_per_threadgroup]]) {
343
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
345
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
346
+ device const float * x_scalar = (device const float *) x;
347
+ float4 sumf=0 ;
348
+ float all_sum=0 ;
344
349
345
350
// parallel sum
346
- sum[tpitg] = 0 .0f ;
347
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
348
- sum[tpitg] += x[i00] * x[i00];
351
+ for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
352
+ sumf += x[i00] * x[i00];
353
+ }
354
+ all_sum = sumf[0 ] + sumf[1 ] + sumf[2 ] + sumf[3 ];
355
+ all_sum = simd_sum (all_sum);
356
+ if (tiisg == 0 ) {
357
+ sum[sgitg] = all_sum;
349
358
}
350
359
351
- // reduce
352
360
threadgroup_barrier (mem_flags::mem_threadgroup);
353
- for ( uint i = ntg/ 2 ; i > 0 ; i /= 2 ) {
354
- if (tpitg < i ) {
355
- sum[ tpitg] += sum[tpitg + i];
356
- }
357
- threadgroup_barrier (mem_flags::mem_threadgroup);
361
+ // broadcast, simd group number is ntg / 32
362
+ for ( int i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
363
+ if ( tpitg < i) {
364
+ sum[tpitg] += sum[tpitg + i];
365
+ }
358
366
}
359
-
360
- // broadcast
361
367
if (tpitg == 0 ) {
368
+ for (int i = 4 * (ne00 / 4 ); i < ne00; i++) {sum[0 ] += x_scalar[i];}
362
369
sum[0 ] /= ne00;
363
370
}
364
371
@@ -367,10 +374,14 @@ kernel void kernel_rms_norm(
367
374
const float mean = sum[0 ];
368
375
const float scale = 1 .0f /sqrt (mean + eps);
369
376
370
- device float * y = dst + tgpig*ne00;
371
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
377
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
378
+ device float * y_scalar = (device float *) y;
379
+ for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
372
380
y[i00] = x[i00] * scale;
373
381
}
382
+ if (tpitg == 0 ) {
383
+ for (int i00 = 4 * (ne00 / 4 ); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
384
+ }
374
385
}
375
386
376
387
// putting them in the kernel cause a significant performance penalty
0 commit comments