@@ -3355,15 +3355,15 @@ kernel void kernel_flash_attn_ext_vec(
     const short NW4 = NW/4;
     const short SH  = 2*C; // shared memory per simdgroup

-    const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+    const short T = D + nsg*SH; // shared memory size per query in (half)

-  //threadgroup q_t    * sq    = (threadgroup q_t    *) (shared +                  0*D); // holds the query data
-    threadgroup q4_t   * sq4   = (threadgroup q4_t   *) (shared +                  0*D); // same as above but in q4_t
-    threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared +                  0*D); // same as above but in q4x4_t
-    threadgroup s_t    * ss    = (threadgroup s_t    *) (shared + 2*sgitg*SH      + Q*D); // scratch buffer for attention
-    threadgroup s4_t   * ss4   = (threadgroup s4_t   *) (shared + 2*sgitg*SH      + Q*D); // same as above but in s4_t
-    threadgroup half   * sm    = (threadgroup half   *) (shared + 2*sgitg*SH + SH + Q*D); // scratch buffer for mask
-    threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared +   sgitg*D       + Q*T); // scratch buffer for the results
+  //threadgroup q_t    * sq    = (threadgroup q_t    *) (shared +                0*D); // holds the query data
+    threadgroup q4_t   * sq4   = (threadgroup q4_t   *) (shared +                0*D); // same as above but in q4_t
+    threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared +                0*D); // same as above but in q4x4_t
+    threadgroup s_t    * ss    = (threadgroup s_t    *) (shared + sgitg*SH     + Q*D); // scratch buffer for attention
+    threadgroup s4_t   * ss4   = (threadgroup s4_t   *) (shared + sgitg*SH     + Q*D); // same as above but in s4_t
+    threadgroup half   * sm    = (threadgroup half   *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
+    threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D      + Q*T); // scratch buffer for the results

     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
     o4x4_t lo[D16/NW4];
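
For orientation, here is a minimal host-side sketch of the offset arithmetic in the new layout: the Q*D query halfs sit at the front, each simdgroup then owns SH = 2*C halfs of scratch (the first C for the attention scores ss, the next C for the mask sm), and the per-simdgroup output accumulators of D halfs each start at Q*T. The concrete numbers (D = 128, C = 32, Q = 1, nsg = 2) are assumed for illustration only and are not taken from the actual kernel launch.

    #include <cstdio>

    int main() {
        // Illustrative parameters only -- the real values come from the kernel instantiation.
        const int D   = 128;        // head size
        const int C   = 32;         // KV-cache items processed per block
        const int Q   = 1;          // queries per threadgroup (vec kernel)
        const int nsg = 2;          // simdgroups per threadgroup

        const int SH = 2*C;         // per-simdgroup scratch, in halfs
        const int T  = D + nsg*SH;  // shared memory per query, in halfs (was D + 2*nsg*SH)

        for (int sgitg = 0; sgitg < nsg; ++sgitg) {
            printf("sgitg %d: ss at %4d, sm at %4d, sr4x4 at %4d (offsets in halfs)\n",
                   sgitg,
                   sgitg*SH     + Q*D,  // attention scores (C halfs)
                   sgitg*SH + C + Q*D,  // mask             (C halfs)
                   sgitg*D      + Q*T); // output accumulator (D halfs)
        }
        return 0;
    }
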
@@ -3522,7 +3522,7 @@ kernel void kernel_flash_attn_ext_vec(
         for (short cc = 0; cc < C/4; ++cc) {
             device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));

-            const v4x4_t ms(ss[4*cc + ty]);
+            const s4x4_t ms(ss[4*cc + ty]);

 #pragma unroll
             for (short ii = 0; ii < D16; ii += NW4) {
@@ -3531,7 +3531,7 @@ kernel void kernel_flash_attn_ext_vec(
                 v4x4_t mv;
                 deq_v(pv4 + i/nl_v, i%nl_v, mv);

-                lo[ii/NW4] += (o4x4_t)(mv*ms);
+                lo[ii/NW4] += mv*ms;
             }
         }
     }
@@ -3616,12 +3616,15 @@ kernel void kernel_flash_attn_ext_vec(
     }
 }

+// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
+//       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
+//
 #define FA_TYPES \
-    half4,  half4x4, \
-            half4x4, \
-            half4x4, \
-    float,           \
-    float,  float4,  float4x4, \
+    half4,  half4x4, \
+            half4x4, \
+            half4x4, \
+    float,           \
+    half,   half4,   half4x4,  \
             half4x4

 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
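
To put a rough number on the change: per the note above, the Q*K scaling happens before the scores are written to shared memory, so the score types can switch from float, float4, float4x4 to half, half4, half4x4 (the FA_TYPES change), and each simdgroup's scratch region shrinks from 2*SH to SH halfs. A small back-of-the-envelope sketch, again using assumed example parameters (D = 128, C = 32, nsg = 2) rather than the real launch values:

    #include <cstdio>

    int main() {
        // Assumed example parameters; the real ones come from the kernel instantiation.
        const int D = 128, C = 32, nsg = 2;
        const int SH = 2*C;              // per-simdgroup scratch, in halfs

        const int T_old = D + 2*nsg*SH;  // s_t was float: each simdgroup spanned 2*SH halfs
        const int T_new = D +   nsg*SH;  // s_t is half:   each simdgroup spans SH halfs (C scores + C mask)

        printf("per-query shared memory: %d -> %d halfs (%d -> %d bytes)\n",
               T_old, T_new, 2*T_old, 2*T_new);  // half is 2 bytes
        return 0;
    }
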