Skip to content

Commit 4088df1

Browse files
committed
metal: update rms_norm kernel
This commit double the speed of rms_norm operations by using 512 threads per threadgroup, combining with SIMD primitives to minimize the need for thread group barriers.
1 parent bbce392 commit 4088df1

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

ggml-metal.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -792,15 +792,15 @@ void ggml_metal_graph_compute(
792792

793793
const float eps = 1e-6f;
794794

795-
const int nth = 256;
795+
const int nth = 512;
796796

797797
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
798798
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
799799
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
800800
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
801801
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
802802
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
803-
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
803+
[encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
804804

805805
const int64_t nrows = ggml_nrows(src0);
806806

ggml-metal.metal

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -339,26 +339,33 @@ kernel void kernel_rms_norm(
339339
threadgroup float * sum [[threadgroup(0)]],
340340
uint tgpig[[threadgroup_position_in_grid]],
341341
uint tpitg[[thread_position_in_threadgroup]],
342+
uint sgitg[[simdgroup_index_in_threadgroup]],
343+
uint tiisg[[thread_index_in_simdgroup]],
342344
uint ntg[[threads_per_threadgroup]]) {
343-
device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
345+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
346+
device const float * x_scalar = (device const float *) x;
347+
float4 sumf=0;
348+
float all_sum=0;
344349

345350
// parallel sum
346-
sum[tpitg] = 0.0f;
347-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
348-
sum[tpitg] += x[i00] * x[i00];
351+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
352+
sumf += x[i00] * x[i00];
353+
}
354+
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
355+
all_sum = simd_sum(all_sum);
356+
if (tiisg == 0) {
357+
sum[sgitg] = all_sum;
349358
}
350359

351-
// reduce
352360
threadgroup_barrier(mem_flags::mem_threadgroup);
353-
for (uint i = ntg/2; i > 0; i /= 2) {
354-
if (tpitg < i) {
355-
sum[tpitg] += sum[tpitg + i];
356-
}
357-
threadgroup_barrier(mem_flags::mem_threadgroup);
361+
// broadcast, simd group number is ntg / 32
362+
for (int i = ntg / 32 / 2; i > 0; i /= 2) {
363+
if (tpitg < i) {
364+
sum[tpitg] += sum[tpitg + i];
365+
}
358366
}
359-
360-
// broadcast
361367
if (tpitg == 0) {
368+
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
362369
sum[0] /= ne00;
363370
}
364371

@@ -367,10 +374,14 @@ kernel void kernel_rms_norm(
367374
const float mean = sum[0];
368375
const float scale = 1.0f/sqrt(mean + eps);
369376

370-
device float * y = dst + tgpig*ne00;
371-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
377+
device float4 * y = (device float4 *) (dst + tgpig*ne00);
378+
device float * y_scalar = (device float *) y;
379+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
372380
y[i00] = x[i00] * scale;
373381
}
382+
if (tpitg == 0) {
383+
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
384+
}
374385
}
375386

376387
// putting them in the kernel cause a significant performance penalty

0 commit comments

Comments
 (0)