@@ -2261,24 +2261,6 @@ kernel void kernel_flash_attn_ext_f16(
2261
2261
}
2262
2262
2263
2263
simdgroup_store (mqk, ss + 8 *cc, TF, 0 , false );
2264
-
2265
- const short tx = tiisg%4 ;
2266
- const short ty = tiisg/4 ;
2267
-
2268
- // mqk = mqk*scale
2269
- ss[8 *cc + ty*TF + 2 *tx + 0 ] *= scale;
2270
- ss[8 *cc + ty*TF + 2 *tx + 1 ] *= scale;
2271
-
2272
- if (logit_softcap != 0 .0f ) {
2273
- ss[8 *cc + ty*TF + 2 *tx + 0 ] = logit_softcap*precise::tanh (ss[8 *cc + ty*TF + 2 *tx + 0 ]);
2274
- ss[8 *cc + ty*TF + 2 *tx + 1 ] = logit_softcap*precise::tanh (ss[8 *cc + ty*TF + 2 *tx + 1 ]);
2275
- }
2276
-
2277
- if (mask != q) {
2278
- // mqk = mqk + mask*slope
2279
- ss[8 *cc + ty*TF + 2 *tx + 0 ] += slope*mp[ic + 8 *cc + ty*nb31/sizeof (half) + 2 *tx + 0 ];
2280
- ss[8 *cc + ty*TF + 2 *tx + 1 ] += slope*mp[ic + 8 *cc + ty*nb31/sizeof (half) + 2 *tx + 1 ];
2281
- }
2282
2264
}
2283
2265
}
2284
2266
@@ -2290,10 +2272,19 @@ kernel void kernel_flash_attn_ext_f16(
2290
2272
float ms[Q];
2291
2273
2292
2274
for (short j = 0 ; j < Q; ++j) {
2293
- const short p = tiisg;
2294
-
2295
2275
const float m = M[j];
2296
- const float s = ss[j*TF + p];
2276
+
2277
+ // scale and apply the logitcap / mask
2278
+ float s = ss[j*TF + tiisg]*scale;
2279
+
2280
+ if (logit_softcap != 0 .0f ) {
2281
+ s = logit_softcap*precise::tanh (s);
2282
+ }
2283
+
2284
+ if (mask != q) {
2285
+ // mqk = mqk + mask*slope
2286
+ s += slope*mp[ic + j*nb31/sizeof (half) + tiisg];
2287
+ }
2297
2288
2298
2289
smax = simd_max (max (smax, s));
2299
2290
M[j] = simd_max (max (M[j], s));
@@ -2304,7 +2295,7 @@ kernel void kernel_flash_attn_ext_f16(
2304
2295
S[j] = S[j]*ms[j] + simd_sum (vs);
2305
2296
2306
2297
// the P matrix from the paper (Q rows, C columns)
2307
- ss[j*TF + p ] = vs;
2298
+ ss[j*TF + tiisg ] = vs;
2308
2299
}
2309
2300
2310
2301
// create a QxQ diagonal matrix for rescaling the output
0 commit comments