Skip to content

Commit c4db592

Browse files
committed
metal : warp-based reduce for rms_norm
1 parent 55717c9 commit c4db592

File tree

2 files changed

+26
-33
lines changed

2 files changed

+26
-33
lines changed

ggml-metal.m

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,15 +1358,19 @@ void ggml_metal_graph_compute(
13581358
float eps;
13591359
memcpy(&eps, dst->op_params, sizeof(float));
13601360

1361-
const int nth = MIN(512, ne00);
1361+
int nth = 32; // SIMD width
1362+
1363+
while (nth < ne00/4 && nth < 1024) {
1364+
nth *= 2;
1365+
}
13621366

13631367
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
1364-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1365-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1366-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1367-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1368-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
1369-
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1368+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1369+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1370+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1371+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1372+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
1373+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
13701374

13711375
const int64_t nrows = ggml_nrows(src0);
13721376

ggml-metal.metal

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -447,14 +447,13 @@ kernel void kernel_rms_norm(
447447
constant int64_t & ne00,
448448
constant uint64_t & nb01,
449449
constant float & eps,
450-
threadgroup float * sum [[threadgroup(0)]],
450+
threadgroup float * buf [[threadgroup(0)]],
451451
uint tgpig[[threadgroup_position_in_grid]],
452452
uint tpitg[[thread_position_in_threadgroup]],
453453
uint sgitg[[simdgroup_index_in_threadgroup]],
454454
uint tiisg[[thread_index_in_simdgroup]],
455455
uint ntg[[threads_per_threadgroup]]) {
456-
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
457-
device const float * x_scalar = (device const float *) x;
456+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
458457

459458
float4 sumf = 0;
460459
float all_sum = 0;
@@ -465,40 +464,30 @@ kernel void kernel_rms_norm(
465464
}
466465
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
467466
all_sum = simd_sum(all_sum);
468-
if (tiisg == 0) {
469-
sum[sgitg] = all_sum;
470-
}
467+
if (ntg > N_SIMDWIDTH) {
468+
if (sgitg == 0) {
469+
buf[tiisg] = 0.0f;
470+
}
471471

472-
threadgroup_barrier(mem_flags::mem_threadgroup);
472+
threadgroup_barrier(mem_flags::mem_threadgroup);
473473

474-
// broadcast, simd group number is ntg / 32
475-
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
476-
if (tpitg < i) {
477-
sum[tpitg] += sum[tpitg + i];
478-
}
479-
}
480-
if (tpitg == 0) {
481-
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
482-
sum[0] += x_scalar[i];
474+
if (tiisg == 0) {
475+
buf[sgitg] = all_sum;
483476
}
484-
sum[0] /= ne00;
485-
}
486477

487-
threadgroup_barrier(mem_flags::mem_threadgroup);
478+
threadgroup_barrier(mem_flags::mem_threadgroup);
488479

489-
const float mean = sum[0];
480+
all_sum = buf[tiisg];
481+
all_sum = simd_sum(all_sum);
482+
}
483+
484+
const float mean = all_sum/ne00;
490485
const float scale = 1.0f/sqrt(mean + eps);
491486

492487
device float4 * y = (device float4 *) (dst + tgpig*ne00);
493-
device float * y_scalar = (device float *) y;
494488
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
495489
y[i00] = x[i00] * scale;
496490
}
497-
if (tpitg == 0) {
498-
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
499-
y_scalar[i00] = x_scalar[i00] * scale;
500-
}
501-
}
502491
}
503492

504493
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])

0 commit comments

Comments
 (0)