File tree 1 file changed +7
-6
lines changed 1 file changed +7
-6
lines changed Original file line number Diff line number Diff line change @@ -196,15 +196,16 @@ static __global__ void flash_attn_vec_ext_f16(
196
196
__syncthreads ();
197
197
198
198
#pragma unroll
199
- for (int j = 0 ; j < ncols; ++j ) {
200
- kqsum[j ] = kqsum_shared[j ][threadIdx .x ];
201
- kqsum[j ] = warp_reduce_sum (kqsum[j ]);
199
+ for (int j_VKQ = 0 ; j_VKQ < ncols; ++j_VKQ ) {
200
+ kqsum[j_VKQ ] = kqsum_shared[j_VKQ ][threadIdx .x ];
201
+ kqsum[j_VKQ ] = warp_reduce_sum (kqsum[j_VKQ ]);
202
202
203
- half dst_val = (__low2half (VKQ[j ]) + __high2half (VKQ[j ]));
203
+ half dst_val = (__low2half (VKQ[j_VKQ ]) + __high2half (VKQ[j_VKQ ]));
204
204
if (parallel_blocks == 1 ) {
205
- dst_val /= kqsum[j ];
205
+ dst_val /= kqsum[j_VKQ ];
206
206
}
207
- dst[D*gridDim .y *(blockIdx .x *ncols + j) + D*blockIdx .y + tid] = dst_val;
207
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
208
+ dst[j_dst*D*gridDim .y + D*blockIdx .y + tid] = dst_val;
208
209
}
209
210
210
211
if (parallel_blocks == 1 || tid != 0 ) {
You can’t perform that action at this time.
0 commit comments