@@ -39,6 +39,8 @@ typedef struct {
39
39
int8_t qs[QK8_0]; // quants
40
40
} block_q8_0;
41
41
42
+ #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
43
+
42
44
// general-purpose kernel for addition of two tensors
43
45
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
44
46
// cons: not very efficient
@@ -207,54 +209,55 @@ kernel void kernel_soft_max(
207
209
lmax = MAX (lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0 .0f ));
208
210
}
209
211
210
- float max = simd_max (lmax);
211
- if (tiisg == 0 ) {
212
- buf[sgitg] = max;
213
- }
212
+ // find the max value in the block
213
+ float max_val = simd_max (lmax);
214
+ if (ntg > N_SIMDWIDTH) {
215
+ if (sgitg == 0 ) {
216
+ buf[tiisg] = -INFINITY;
217
+ }
214
218
215
- threadgroup_barrier (mem_flags::mem_threadgroup);
219
+ threadgroup_barrier (mem_flags::mem_threadgroup);
216
220
217
- // broadcast, simd group number is ntg / 32
218
- for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
219
- if (tpitg < i) {
220
- buf[tpitg] = MAX (buf[tpitg], buf[tpitg + i]);
221
- }
222
- }
221
+ if (tiisg == 0 ) {
222
+ buf[sgitg] = max_val;
223
+ }
223
224
224
- threadgroup_barrier (mem_flags::mem_threadgroup);
225
+ threadgroup_barrier (mem_flags::mem_threadgroup);
225
226
226
- max = buf[0 ];
227
+ max_val = buf[tiisg];
228
+ max_val = simd_max (max_val);
229
+ }
227
230
228
231
// parallel sum
229
232
float lsum = 0 .0f ;
230
233
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
231
- const float exp_psrc0 = exp ((psrc0[i00]*scale + (pmask ? pmask[i00] : 0 .0f )) - max );
234
+ const float exp_psrc0 = exp ((psrc0[i00]*scale + (pmask ? pmask[i00] : 0 .0f )) - max_val );
232
235
lsum += exp_psrc0;
233
- // Remember the result of exp here. exp is expensive, so we really do not
234
- // wish to compute it twice.
235
236
pdst[i00] = exp_psrc0;
236
237
}
237
238
238
239
float sum = simd_sum (lsum);
239
- if (tiisg == 0 ) {
240
- buf[sgitg] = sum;
241
- }
240
+ if (ntg > N_SIMDWIDTH) {
241
+ if (sgitg == 0 ) {
242
+ buf[tiisg] = 0 .0f ;
243
+ }
242
244
243
- threadgroup_barrier (mem_flags::mem_threadgroup);
245
+ threadgroup_barrier (mem_flags::mem_threadgroup);
244
246
245
- // broadcast, simd group number is ntg / 32
246
- for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
247
- if (tpitg < i) {
248
- buf[tpitg] += buf[tpitg + i];
249
- }
250
- }
247
+ if (tiisg == 0 ) {
248
+ buf[sgitg] = sum;
249
+ }
251
250
252
- threadgroup_barrier (mem_flags::mem_threadgroup);
251
+ threadgroup_barrier (mem_flags::mem_threadgroup);
252
+
253
+ sum = buf[tiisg];
254
+ sum = simd_sum (sum);
255
+ }
253
256
254
- sum = buf[ 0 ] ;
257
+ const float inv_sum = 1 . 0f /sum ;
255
258
256
259
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
257
- pdst[i00] /= sum ;
260
+ pdst[i00] *= inv_sum ;
258
261
}
259
262
}
260
263
@@ -288,53 +291,56 @@ kernel void kernel_soft_max_4(
288
291
}
289
292
290
293
const float lmax = MAX (MAX (lmax4[0 ], lmax4[1 ]), MAX (lmax4[2 ], lmax4[3 ]));
291
- float max = simd_max (lmax);
292
- if (tiisg == 0 ) {
293
- buf[sgitg] = max;
294
- }
295
294
296
- threadgroup_barrier (mem_flags::mem_threadgroup);
295
+ float max_val = simd_max (lmax);
296
+ if (ntg > N_SIMDWIDTH) {
297
+ if (sgitg == 0 ) {
298
+ buf[tiisg] = -INFINITY;
299
+ }
297
300
298
- // broadcast, simd group number is ntg / 32
299
- for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
300
- if (tpitg < i) {
301
- buf[tpitg] = MAX (buf[tpitg], buf[tpitg + i]);
302
- }
303
- }
301
+ threadgroup_barrier (mem_flags::mem_threadgroup);
304
302
305
- threadgroup_barrier (mem_flags::mem_threadgroup);
303
+ if (tiisg == 0 ) {
304
+ buf[sgitg] = max_val;
305
+ }
306
306
307
- max = buf[0 ];
307
+ threadgroup_barrier (mem_flags::mem_threadgroup);
308
+
309
+ max_val = buf[tiisg];
310
+ max_val = simd_max (max_val);
311
+ }
308
312
309
313
// parallel sum
310
314
float4 lsum4 = 0 .0f ;
311
315
for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
312
- const float4 exp_psrc4 = exp ((psrc4[i00]*scale + (pmask ? pmask[i00] : 0 .0f )) - max );
316
+ const float4 exp_psrc4 = exp ((psrc4[i00]*scale + (pmask ? pmask[i00] : 0 .0f )) - max_val );
313
317
lsum4 += exp_psrc4;
314
318
pdst4[i00] = exp_psrc4;
315
319
}
316
320
317
321
const float lsum = lsum4[0 ] + lsum4[1 ] + lsum4[2 ] + lsum4[3 ];
318
322
float sum = simd_sum (lsum);
319
- if (tiisg == 0 ) {
320
- buf[sgitg] = sum;
321
- }
323
+ if (ntg > N_SIMDWIDTH) {
324
+ if (sgitg == 0 ) {
325
+ buf[tiisg] = 0 .0f ;
326
+ }
322
327
323
- threadgroup_barrier (mem_flags::mem_threadgroup);
328
+ threadgroup_barrier (mem_flags::mem_threadgroup);
324
329
325
- // broadcast, simd group number is ntg / 32
326
- for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
327
- if (tpitg < i) {
328
- buf[tpitg] += buf[tpitg + i];
329
- }
330
- }
330
+ if (tiisg == 0 ) {
331
+ buf[sgitg] = sum;
332
+ }
331
333
332
- threadgroup_barrier (mem_flags::mem_threadgroup);
334
+ threadgroup_barrier (mem_flags::mem_threadgroup);
335
+
336
+ sum = buf[tiisg];
337
+ sum = simd_sum (sum);
338
+ }
333
339
334
- sum = buf[ 0 ] ;
340
+ const float inv_sum = 1 . 0f /sum ;
335
341
336
342
for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
337
- pdst4[i00] /= sum ;
343
+ pdst4[i00] *= inv_sum ;
338
344
}
339
345
}
340
346
@@ -582,7 +588,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
582
588
// putting them in the kernel cause a significant performance penalty
583
589
#define N_DST 4 // each SIMD group works on 4 rows
584
590
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
585
- #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
586
591
// Note: This is a template, but strictly speaking it only applies to
587
592
// quantizations where the block size is 32. It also does not
588
593
// giard against the number of rows not being divisible by
0 commit comments