Skip to content

Commit c71e0bc

Browse files
committed
metal : use F16 precision in FA kernel
1 parent 1e12961 commit c71e0bc

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

ggml/src/ggml-metal.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3150,7 +3150,7 @@ static void ggml_metal_encode_node(
31503150
while (true) {
31513151
// 16*32*nsgmax - the shared memory needed for the simdgroups to load the KV cache
31523152
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
3153-
const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 16*32*nsgmax)*(sizeof(float)/2);
3153+
const size_t smem = (nqptg*(ne00 + nsgmax*(ncpsg + nqptg)) + 16*32*nsgmax)*(sizeof(float)/2);
31543154
if (smem > device.maxThreadgroupMemoryLength) {
31553155
break;
31563156
}
@@ -3161,7 +3161,7 @@ static void ggml_metal_encode_node(
31613161
// simdgroups per threadgroup (a.k.a. warps)
31623162
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
31633163

3164-
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2);
3164+
const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2);
31653165

31663166
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
31673167
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);

ggml/src/ggml-metal.metal

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2774,13 +2774,13 @@ kernel void kernel_flash_attn_ext(
27742774
const short NW = N_SIMDWIDTH;
27752775
const short SH = (C + Q); // shared memory per simdgroup in (half)
27762776

2777-
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
2778-
const short TF = T/2; // shared memory size per query in (float)
2777+
const short T = D + nsg*SH; // shared memory size per query in (half)
2778+
const short TF = T; // shared memory size per query in (float)
27792779
const short T4 = T/4; // shared memory size per query in (half4)
27802780

2781-
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2782-
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2783-
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
2781+
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2782+
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2783+
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
27842784

27852785
threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
27862786
threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
@@ -2809,7 +2809,7 @@ kernel void kernel_flash_attn_ext(
28092809
// zero out shared memory SH
28102810
for (short j = 0; j < Q; ++j) {
28112811
for (short i = tiisg; i < SH; i += NW) {
2812-
ss[j*TF + i] = 0.0f;
2812+
ss[j*TF + i] = 0.0h;
28132813
}
28142814
}
28152815

@@ -2874,7 +2874,7 @@ kernel void kernel_flash_attn_ext(
28742874
// Q*K^T
28752875
{
28762876
for (short cc = 0; cc < C/8; ++cc) {
2877-
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
2877+
simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, 8>(0.h);
28782878

28792879
if (is_same<block_q, half4x4>::value) {
28802880
// we can read directly from global memory
@@ -2944,7 +2944,7 @@ kernel void kernel_flash_attn_ext(
29442944
const float m = M[j];
29452945

29462946
// scale and apply the logitcap / mask
2947-
float s = ss[j*TF + tiisg]*scale;
2947+
float s = ((float)(ss[j*TF + tiisg]))*scale;
29482948

29492949
if (logit_softcap != 0.0f) {
29502950
s = logit_softcap*precise::tanh(s);
@@ -2980,7 +2980,7 @@ kernel void kernel_flash_attn_ext(
29802980

29812981
// O = diag(ms)*O
29822982
{
2983-
simdgroup_float8x8 mm;
2983+
simdgroup_half8x8 mm;
29842984
simdgroup_load(mm, ss + C, TF, 0, false);
29852985

29862986
for (short i = 0; i < D8; ++i) {
@@ -2991,7 +2991,7 @@ kernel void kernel_flash_attn_ext(
29912991
// O = O + (Q*K^T)*V
29922992
{
29932993
for (short cc = 0; cc < C/8; ++cc) {
2994-
simdgroup_float8x8 ms;
2994+
simdgroup_half8x8 ms;
29952995
simdgroup_load(ms, ss + 8*cc, TF, 0, false);
29962996

29972997
if (is_same<block_q, half4x4>::value) {
@@ -3103,8 +3103,8 @@ kernel void kernel_flash_attn_ext(
31033103
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
31043104
{
31053105
simdgroup_half8x8 t;
3106-
simdgroup_float8x8 ms0;
3107-
simdgroup_float8x8 ms1;
3106+
simdgroup_half8x8 ms0;
3107+
simdgroup_half8x8 ms1;
31083108

31093109
simdgroup_load(ms0, ss + C, TF, 0, false);
31103110
simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);

0 commit comments

Comments (0)