Skip to content

Commit e4643c9

Browse files
fix batch size 2-8
1 parent eea1184 commit e4643c9

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

ggml-cuda/fattn.cu

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -196,15 +196,16 @@ static __global__ void flash_attn_vec_ext_f16(
196196
__syncthreads();
197197

198198
#pragma unroll
199-
for (int j = 0; j < ncols; ++j) {
200-
kqsum[j] = kqsum_shared[j][threadIdx.x];
201-
kqsum[j] = warp_reduce_sum(kqsum[j]);
199+
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
200+
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
201+
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
202202

203-
half dst_val = (__low2half(VKQ[j]) + __high2half(VKQ[j]));
203+
half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
204204
if (parallel_blocks == 1) {
205-
dst_val /= kqsum[j];
205+
dst_val /= kqsum[j_VKQ];
206206
}
207-
dst[D*gridDim.y*(blockIdx.x*ncols + j) + D*blockIdx.y + tid] = dst_val;
207+
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
208+
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
208209
}
209210

210211
if (parallel_blocks == 1 || tid != 0) {

0 commit comments

Comments (0)