@@ -2774,13 +2774,13 @@ kernel void kernel_flash_attn_ext(
     const short NW = N_SIMDWIDTH;
     const short SH = (C + Q); // shared memory per simdgroup in (half)

-    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
-    const short TF = T/2;          // shared memory size per query in (float)
+    const short T  = D + nsg*SH; // shared memory size per query in (half)
+    const short TF = T;          // shared memory size per query in (float)
     const short T4 = T/4;        // shared memory size per query in (half4)

-    threadgroup half  * sq  = (threadgroup half  *) (shared +              0*D); // holds the query data
-    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4
-    threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup half  * sq  = (threadgroup half  *) (shared +            0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +            0*D); // same as above but in half4
+    threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix

     threadgroup half    * skv  = (threadgroup half    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
     threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
@@ -2809,7 +2809,7 @@ kernel void kernel_flash_attn_ext(
     // zero out shared memory SH
     for (short j = 0; j < Q; ++j) {
         for (short i = tiisg; i < SH; i += NW) {
-            ss[j*TF + i] = 0.0f;
+            ss[j*TF + i] = 0.0h;
         }
     }

@@ -2874,7 +2874,7 @@ kernel void kernel_flash_attn_ext(
             // Q*K^T
             {
                 for (short cc = 0; cc < C/8; ++cc) {
-                    simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+                    simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, 8>(0.h);

                    // this is compile-time check, so it does not have runtime overhead
                    if (is_same<block_q, half4x4>::value) {
@@ -2946,7 +2946,7 @@ kernel void kernel_flash_attn_ext(
                    const float m = M[j];

                    // scale and apply the logitcap / mask
-                    float s = ss[j*TF + tiisg]*scale;
+                    float s = ((float)(ss[j*TF + tiisg]))*scale;

                    if (logit_softcap != 0.0f) {
                        s = logit_softcap*precise::tanh(s);
@@ -2982,7 +2982,7 @@ kernel void kernel_flash_attn_ext(

                // O = diag(ms)*O
                {
-                    simdgroup_float8x8 mm;
+                    simdgroup_half8x8 mm;
                    simdgroup_load(mm, ss + C, TF, 0, false);

                    for (short i = 0; i < D8; ++i) {
@@ -2993,7 +2993,7 @@ kernel void kernel_flash_attn_ext(
                // O = O + (Q*K^T)*V
                {
                    for (short cc = 0; cc < C/8; ++cc) {
-                        simdgroup_float8x8 ms;
+                        simdgroup_half8x8 ms;
                        simdgroup_load(ms, ss + 8*cc, TF, 0, false);

                        if (is_same<block_q, half4x4>::value) {
@@ -3106,8 +3106,8 @@ kernel void kernel_flash_attn_ext(
            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
            {
                simdgroup_half8x8 t;
-                simdgroup_float8x8 ms0;
-                simdgroup_float8x8 ms1;
+                simdgroup_half8x8 ms0;
+                simdgroup_half8x8 ms1;

                simdgroup_load(ms0, ss + C,         TF, 0, false);
                simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
@@ -3188,6 +3188,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_
 template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;

 // NOTE: can use half instead of float precision for some extra perf
+// however, by default use F32 since the op should be mostly memory bandwidth bound
 // D - head size, Q - queries per threadgroup, C - cache items per threadgroup
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
 kernel void kernel_flash_attn_ext_vec(
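All of the hunks above make the same substitution: the attention scratch buffer `ss` and the simdgroup accumulators move from float to half, which is why `TF` collapses from `T/2` to `T` and the per-query allocation shrinks from `D + 2*nsg*SH` to `D + nsg*SH` halves. A minimal sketch of how the two precisions could be kept behind a single compile-time switch (purely illustrative; FA_SCRATCH_F16, s_t and s8x8_t are hypothetical names, not identifiers from this commit):

// Hypothetical sketch, not part of this commit: select the scratch precision
// once and let the kernel body use the aliases instead of hard-coding half.
#include <metal_stdlib>
using namespace metal;

#if defined(FA_SCRATCH_F16)
typedef half               s_t;     // element type of the ss scratch buffer
typedef simdgroup_half8x8  s8x8_t;  // accumulator type for Q*K^T and diag(ms)
#else
typedef float              s_t;     // original F32 path
typedef simdgroup_float8x8 s8x8_t;
#endif

// The size computations and matrix fills then follow the alias, e.g.:
//   const short TF = T*sizeof(half)/sizeof(s_t);                    // T for half, T/2 for float
//   s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>((s_t) 0.0f);  // replaces the per-type literals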