Skip to content

Commit 06658ad

Browse files
authored
metal : separate scale and mask from QKT in FA kernel (#9189)
* metal : separate scale and mask from QKT in FA kernel
* metal : ne01 check no longer necessary
* metal : keep data in local memory
1 parent fc18425 commit 06658ad

File tree

1 file changed

+13
-22
lines changed

1 file changed

+13
-22
lines changed

ggml/src/ggml-metal.metal

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2261,24 +2261,6 @@ kernel void kernel_flash_attn_ext_f16(
22612261
}
22622262

22632263
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2264-
2265-
const short tx = tiisg%4;
2266-
const short ty = tiisg/4;
2267-
2268-
// mqk = mqk*scale
2269-
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
2270-
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
2271-
2272-
if (logit_softcap != 0.0f) {
2273-
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
2274-
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
2275-
}
2276-
2277-
if (mask != q) {
2278-
// mqk = mqk + mask*slope
2279-
ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
2280-
ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
2281-
}
22822264
}
22832265
}
22842266

@@ -2290,10 +2272,19 @@ kernel void kernel_flash_attn_ext_f16(
22902272
float ms[Q];
22912273

22922274
for (short j = 0; j < Q; ++j) {
2293-
const short p = tiisg;
2294-
22952275
const float m = M[j];
2296-
const float s = ss[j*TF + p];
2276+
2277+
// scale and apply the logitcap / mask
2278+
float s = ss[j*TF + tiisg]*scale;
2279+
2280+
if (logit_softcap != 0.0f) {
2281+
s = logit_softcap*precise::tanh(s);
2282+
}
2283+
2284+
if (mask != q) {
2285+
// mqk = mqk + mask*slope
2286+
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
2287+
}
22972288

22982289
smax = simd_max(max(smax, s));
22992290
M[j] = simd_max(max(M[j], s));
@@ -2304,7 +2295,7 @@ kernel void kernel_flash_attn_ext_f16(
23042295
S[j] = S[j]*ms[j] + simd_sum(vs);
23052296

23062297
// the P matrix from the paper (Q rows, C columns)
2307-
ss[j*TF + p] = vs;
2298+
ss[j*TF + tiisg] = vs;
23082299
}
23092300

23102301
// create a QxQ diagonal matrix for rescaling the output

0 commit comments

Comments
 (0)