@@ -2942,6 +2942,7 @@ kernel void kernel_flash_attn_ext(
2942
2942
half smax = -INFINITY;
2943
2943
2944
2944
// load the mask in shared memory
2945
+ #pragma unroll(Q)
2945
2946
for (short j = 0 ; j < Q; ++j) {
2946
2947
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
2947
2948
@@ -2968,7 +2969,7 @@ kernel void kernel_flash_attn_ext(
2968
2969
// we can read directly from global memory
2969
2970
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8 *cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
2970
2971
2971
- #pragma unroll
2972
+ #pragma unroll(D8)
2972
2973
for (short i = 0 ; i < D8; ++i) {
2973
2974
k8x8_t mk;
2974
2975
simdgroup_load (mk, pk + i*8 , nb_12_1/sizeof (k_t ), 0 , true ); // transpose // TODO: use ne10
@@ -2989,7 +2990,7 @@ kernel void kernel_flash_attn_ext(
2989
2990
2990
2991
simdgroup_barrier (mem_flags::mem_threadgroup);
2991
2992
2992
- #pragma unroll
2993
+ #pragma unroll(4)
2993
2994
for (short k = 0 ; k < 4 ; ++k) {
2994
2995
k8x8_t mk;
2995
2996
@@ -3067,7 +3068,7 @@ kernel void kernel_flash_attn_ext(
3067
3068
s8x8_t mm;
3068
3069
simdgroup_load (mm, ss + 2 *C, TS, 0 , false );
3069
3070
3070
- #pragma unroll
3071
+ #pragma unroll(D8)
3071
3072
for (short i = 0 ; i < D8; ++i) {
3072
3073
simdgroup_multiply (lo[i], mm, lo[i]);
3073
3074
}
@@ -3082,7 +3083,8 @@ kernel void kernel_flash_attn_ext(
3082
3083
if (is_same<vd4x4_t , v4x4_t >::value) {
3083
3084
// we can read directly from global memory
3084
3085
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8 *cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3085
- #pragma unroll
3086
+
3087
+ #pragma unroll(D8)
3086
3088
for (short i = 0 ; i < D8; ++i) {
3087
3089
v8x8_t mv;
3088
3090
simdgroup_load (mv, pv + i*8 , nb_12_1/sizeof (v_t ), 0 , false ); // TODO: use ne20
@@ -3103,7 +3105,7 @@ kernel void kernel_flash_attn_ext(
3103
3105
3104
3106
simdgroup_barrier (mem_flags::mem_threadgroup);
3105
3107
3106
- #pragma unroll
3108
+ #pragma unroll(4)
3107
3109
for (short k = 0 ; k < 4 ; ++k) {
3108
3110
v8x8_t mv;
3109
3111
@@ -3196,6 +3198,7 @@ kernel void kernel_flash_attn_ext(
3196
3198
simdgroup_load (ms0, ss + 2 *C, TS, 0 , false );
3197
3199
simdgroup_load (ms1, ss + 2 *C + sg*SH, TS, 0 , false );
3198
3200
3201
+ #pragma unroll(D8)
3199
3202
for (short i = 0 ; i < D8; ++i) {
3200
3203
o8x8_t t;
3201
3204
@@ -3413,6 +3416,7 @@ kernel void kernel_flash_attn_ext_vec(
3413
3416
// load the queries from shared memory into local memory
3414
3417
q4x4_t mq[D16/NL];
3415
3418
3419
+ #pragma unroll(D16/NL)
3416
3420
for (short ii = 0 ; ii < D16; ii += NL) {
3417
3421
mq[ii/NL] = sq4x4[ii + tx];
3418
3422
}
@@ -3454,17 +3458,23 @@ kernel void kernel_flash_attn_ext_vec(
3454
3458
3455
3459
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4 *cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3456
3460
3457
- #pragma unroll
3461
+ #pragma unroll(D16/NL)
3458
3462
for (short ii = 0 ; ii < D16; ii += NL) {
3459
3463
const short i = ii + tx;
3460
3464
3461
3465
k4x4_t mk;
3462
3466
deq_k (pk + i/nl_k, i%nl_k, mk);
3463
3467
3464
- mqka[0 ] += dot (mq[ii/NL][0 ], mk[0 ]);
3465
- mqka[1 ] += dot (mq[ii/NL][1 ], mk[1 ]);
3466
- mqka[2 ] += dot (mq[ii/NL][2 ], mk[2 ]);
3467
- mqka[3 ] += dot (mq[ii/NL][3 ], mk[3 ]);
3468
+ // note: this is less precise than the version below
3469
+ // mqka[0] += dot(mq[ii/NL][0], mk[0]);
3470
+ // mqka[1] += dot(mq[ii/NL][1], mk[1]);
3471
+ // mqka[2] += dot(mq[ii/NL][2], mk[2]);
3472
+ // mqka[3] += dot(mq[ii/NL][3], mk[3]);
3473
+
3474
+ mqka[0 ] += dot ((float4) mq[ii/NL][0 ], (float4) mk[0 ]);
3475
+ mqka[1 ] += dot ((float4) mq[ii/NL][1 ], (float4) mk[1 ]);
3476
+ mqka[2 ] += dot ((float4) mq[ii/NL][2 ], (float4) mk[2 ]);
3477
+ mqka[3 ] += dot ((float4) mq[ii/NL][3 ], (float4) mk[3 ]);
3468
3478
}
3469
3479
3470
3480
qk_t mqk = mqka[0 ] + mqka[1 ] + mqka[2 ] + mqka[3 ];
@@ -3513,7 +3523,7 @@ kernel void kernel_flash_attn_ext_vec(
3513
3523
ss[tiisg] = vs;
3514
3524
3515
3525
// O = diag(ms)*O
3516
- #pragma unroll
3526
+ #pragma unroll(D16/NL)
3517
3527
for (short ii = 0 ; ii < D16; ii += NL) {
3518
3528
lo[ii/NL] *= ms;
3519
3529
}
@@ -3523,13 +3533,12 @@ kernel void kernel_flash_attn_ext_vec(
3523
3533
3524
3534
// O = O + (Q*K^T)*V
3525
3535
{
3526
- #pragma unroll
3527
3536
for (short cc = 0 ; cc < C/4 ; ++cc) {
3528
3537
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4 *cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3529
3538
3530
3539
const s4x4_t ms (ss[4 *cc + ty]);
3531
3540
3532
- #pragma unroll
3541
+ #pragma unroll(D16/NL)
3533
3542
for (short ii = 0 ; ii < D16; ii += NL) {
3534
3543
const short i = ii + tx;
3535
3544
0 commit comments