Skip to content

Commit ff1b4f5

Browse files
committed
metal : minor clean-up
1 parent f66d362 commit ff1b4f5

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

ggml/src/ggml-metal.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3292,12 +3292,12 @@ static void ggml_metal_encode_node(
32923292

32933293
// ne00 + 2*ncpsg*(nsg)
32943294
// for each query, we load it as f16 in shared memory (ne00)
3295-
// and store the attention scores (nqptg x ncpsg) as f32
3295+
// and store the soft_max values and the mask
32963296
//
32973297
// ne00*(nsg)
32983298
// each simdgroup has a full f16 head vector in shared mem to accumulate results
32993299
//
3300-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 4*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
3300+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
33013301

33023302
int64_t nsgmax = 2;
33033303

ggml/src/ggml-metal.metal

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3355,15 +3355,15 @@ kernel void kernel_flash_attn_ext_vec(
33553355
const short NW4 = NW/4;
33563356
const short SH = 2*C; // shared memory per simdgroup
33573357

3358-
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
3358+
const short T = D + nsg*SH; // shared memory size per query in (half)
33593359

3360-
//threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
3361-
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
3362-
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t
3363-
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention
3364-
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + 2*sgitg*SH + Q*D); // same as above but in s4_t
3365-
threadgroup half * sm = (threadgroup half *) (shared + 2*sgitg*SH + SH + Q*D); // scratch buffer for mask
3366-
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
3360+
//threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
3361+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
3362+
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t
3363+
threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention
3364+
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t
3365+
threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
3366+
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
33673367

33683368
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
33693369
o4x4_t lo[D16/NW4];
@@ -3522,7 +3522,7 @@ kernel void kernel_flash_attn_ext_vec(
35223522
for (short cc = 0; cc < C/4; ++cc) {
35233523
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
35243524

3525-
const v4x4_t ms(ss[4*cc + ty]);
3525+
const s4x4_t ms(ss[4*cc + ty]);
35263526

35273527
#pragma unroll
35283528
for (short ii = 0; ii < D16; ii += NW4) {
@@ -3531,7 +3531,7 @@ kernel void kernel_flash_attn_ext_vec(
35313531
v4x4_t mv;
35323532
deq_v(pv4 + i/nl_v, i%nl_v, mv);
35333533

3534-
lo[ii/NW4] += (o4x4_t)(mv*ms);
3534+
lo[ii/NW4] += mv*ms;
35353535
}
35363536
}
35373537
}
@@ -3616,12 +3616,15 @@ kernel void kernel_flash_attn_ext_vec(
36163616
}
36173617
}
36183618

3619+
// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
3620+
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
3621+
//
36193622
#define FA_TYPES \
3620-
half4, half4x4, \
3621-
half4x4, \
3622-
half4x4, \
3623-
float, \
3624-
float, float4, float4x4, \
3623+
half4, half4x4, \
3624+
half4x4, \
3625+
half4x4, \
3626+
float, \
3627+
half, half4, half4x4, \
36253628
half4x4
36263629

36273630
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;

0 commit comments

Comments
 (0)