
Commit 55717c9

metal : warp-based reduction for soft max kernel
1 parent 68e02c0 commit 55717c9

File tree: 2 files changed, +68 -61 lines

ggml-metal.m

Lines changed: 6 additions & 4 deletions
@@ -1028,12 +1028,14 @@ void ggml_metal_graph_compute(
         int nth = 32; // SIMD width

         if (ne00%4 == 0) {
+            while (nth < ne00/4 && nth < 256) {
+                nth *= 2;
+            }
             [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
         } else {
-            do {
+            while (nth < ne00 && nth < 1024) {
                 nth *= 2;
-            } while (nth <= ne00 && nth <= 1024);
-            nth /= 2;
+            }
             [encoder setComputePipelineState:ctx->pipeline_soft_max];
         }

@@ -1046,7 +1048,7 @@ void ggml_metal_graph_compute(
         [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
         [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
         [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
-        [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+        [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

         [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
     } break;
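
The host-side change replaces the old overshoot-then-halve do/while (double nth past the limit, then divide it back down) with a straight doubling loop: nth grows until it covers the row, capped at 256 threads for the float4 kernel (each thread consumes four elements per step) and 1024 for the scalar one. A minimal standalone sketch of this selection logic in plain C++ (the function name choose_nth and the example values are illustrative, not part of the commit):

    #include <cstdio>

    // Double the threadgroup size until it covers the row or hits the cap,
    // mirroring the loops above. `elems` is ne00 for the scalar soft max
    // kernel and ne00/4 for the float4 variant; `cap` is 1024 or 256.
    static int choose_nth(int elems, int cap) {
        int nth = 32; // start at one SIMD group (SIMD width)
        while (nth < elems && nth < cap) {
            nth *= 2; // threadgroup sizes stay powers of two
        }
        return nth;
    }

    int main() {
        printf("%d\n", choose_nth(512,  1024)); // scalar, ne00 = 512   -> 512
        printf("%d\n", choose_nth(128,   256)); // float4, ne00/4 = 128 -> 128
        printf("%d\n", choose_nth(4096, 1024)); // long row, capped     -> 1024
        return 0;
    }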

ggml-metal.metal

Lines changed: 62 additions & 57 deletions
@@ -39,6 +39,8 @@ typedef struct {
     int8_t qs[QK8_0]; // quants
 } block_q8_0;

+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
 // general-purpose kernel for addition of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
 // cons: not very efficient
@@ -207,54 +209,55 @@ kernel void kernel_soft_max(
         lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
     }

-    float max = simd_max(lmax);
-    if (tiisg == 0) {
-        buf[sgitg] = max;
-    }
+    // find the max value in the block
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }

-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);

-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-        }
-    }
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }

-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);

-    max = buf[0];
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }

     // parallel sum
     float lsum = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
-        // Remember the result of exp here. exp is expensive, so we really do not
-        // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }

     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }

-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);

-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] += buf[tpitg + i];
-        }
-    }
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }

-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }

-    sum = buf[0];
+    const float inv_sum = 1.0f/sum;

     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        pdst[i00] /= sum;
+        pdst[i00] *= inv_sum;
     }
 }
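
Both reductions in kernel_soft_max now follow the same warp-based two-stage pattern (N_SIMDWIDTH, hoisted to the top of the file above, replaces the copy that used to live next to N_DST/N_SIMDGROUP; see the last hunk below). Stage one reduces within each SIMD group via simd_max/simd_sum, which needs no shared memory; lane 0 of each group then writes its partial into buf, and every group re-reduces the at most 32 partials with a second SIMD reduction, so all threads hold the result without the old log-step tree of barriers. When the threadgroup is a single SIMD group (ntg <= N_SIMDWIDTH), the shared-memory round is skipped entirely. Unused buf slots are primed with the reduction identity (-INFINITY for max, 0.0f for sum), which is why a fixed 32 floats of threadgroup memory always suffices on the host side. A compilable Metal sketch of the max stage in isolation (the kernel name row_max and its buffer layout are illustrative, not from the commit):

    #include <metal_stdlib>
    using namespace metal;

    #define N_SIMDWIDTH 32 // assuming SIMD group size is 32

    // Reduce one row to its maximum with the same two-stage pattern as
    // kernel_soft_max. Expects 32 floats of threadgroup memory at index 0.
    kernel void row_max(
            device const float * src [[buffer(0)]],
            device       float * dst [[buffer(1)]],
            constant     int   & n   [[buffer(2)]],
            threadgroup  float * buf [[threadgroup(0)]],
            uint tpitg[[thread_position_in_threadgroup]],
            uint ntg  [[threads_per_threadgroup]],
            uint tiisg[[thread_index_in_simdgroup]],
            uint sgitg[[simdgroup_index_in_threadgroup]]) {
        // each thread scans the row with a stride of ntg
        float lmax = -INFINITY;
        for (int i = tpitg; i < n; i += ntg) {
            lmax = max(lmax, src[i]);
        }

        // stage 1: reduce within each SIMD group; no shared memory needed
        float max_val = simd_max(lmax);

        // stage 2: combine the per-SIMD-group partials through buf
        if (ntg > N_SIMDWIDTH) {
            if (sgitg == 0) {
                buf[tiisg] = -INFINITY; // pad unused slots with the identity
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);

            if (tiisg == 0) {
                buf[sgitg] = max_val;   // one partial per SIMD group
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);

            // every SIMD group re-reduces the <= 32 partials, so all
            // threads end up with the row max without a broadcast step
            max_val = simd_max(buf[tiisg]);
        }

        if (tpitg == 0) {
            dst[0] = max_val;
        }
    }

The sum stage is identical with 0.0f as the identity and simd_sum in place of simd_max; kernel_soft_max_4 below applies the same pattern to float4 loads.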

@@ -288,53 +291,56 @@ kernel void kernel_soft_max_4(
     }

     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
-    float max = simd_max(lmax);
-    if (tiisg == 0) {
-        buf[sgitg] = max;
-    }

-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }

-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-        }
-    }
+        threadgroup_barrier(mem_flags::mem_threadgroup);

-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }

-    max = buf[0];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }

     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }

     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }

-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);

-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] += buf[tpitg + i];
-        }
-    }
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }

-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }

-    sum = buf[0];
+    const float inv_sum = 1.0f/sum;

     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        pdst4[i00] /= sum;
+        pdst4[i00] *= inv_sum;
     }
 }

@@ -582,7 +588,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 // quantizations where the block size is 32. It also does not
 // giard against the number of rows not being divisible by
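
The per-row math is unchanged by this commit: scale and mask the logits, subtract the row max for numerical stability, exponentiate once and cache the result in the output (exp is expensive, as the removed comment in kernel_soft_max noted), then normalize by multiplying with the precomputed reciprocal of the sum instead of dividing per element. A plain C++ reference of that computation (soft_max_ref is an illustrative name, not code from the commit):

    #include <algorithm>
    #include <cmath>

    // Reference for what kernel_soft_max computes on one row:
    // dst = softmax(src*scale + mask).
    void soft_max_ref(const float * src, const float * mask,
                      float * dst, int ne00, float scale) {
        float max_val = -INFINITY;
        for (int i = 0; i < ne00; ++i) {
            max_val = std::max(max_val, src[i]*scale + (mask ? mask[i] : 0.0f));
        }

        float sum = 0.0f;
        for (int i = 0; i < ne00; ++i) {
            const float e = std::exp((src[i]*scale + (mask ? mask[i] : 0.0f)) - max_val);
            sum   += e;
            dst[i] = e; // cache exp(); recomputing it would be costly
        }

        const float inv_sum = 1.0f/sum;
        for (int i = 0; i < ne00; ++i) {
            dst[i] *= inv_sum; // multiply by the reciprocal, as the kernels now do
        }
    }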
