Skip to content

Commit 2272765

Browse files
fix performance regression
1 parent fa81c3a commit 2272765

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

ggml-cuda/fattn.cu

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,15 @@ static __global__ void flash_attn_vec_ext_f16(
106106
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
107107
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
108108
// Calculate KQ tile and keep track of new maximum KQ values:
109-
half kqmax_new[ncols];
109+
110+
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
111+
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
112+
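// Therefore the running maximum is declared both as a scalar (kqmax_new, used when ncols == 1) and as an array (kqmax_new_arr, used otherwise); only one of the two is used for a given ncols, so the compiler can optimize out the unused one.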
113+
half kqmax_new = kqmax[0];
114+
half kqmax_new_arr[ncols];
110115
#pragma unroll
111116
for (int j = 0; j < ncols; ++j) {
112-
kqmax_new[j] = kqmax[j];
117+
kqmax_new_arr[j] = kqmax[j];
113118
}
114119

115120
#pragma unroll
@@ -137,7 +142,13 @@ static __global__ void flash_attn_vec_ext_f16(
137142
sum2[j] = warp_reduce_sum(sum2[j]);
138143
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
139144
sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
140-
kqmax_new[j] = ggml_cuda_hmax(kqmax_new[j], sum);
145+
146+
if (ncols == 1) {
147+
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
148+
} else {
149+
kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
150+
}
151+
141152
if (threadIdx.x == 0) {
142153
KQ[j*D + i_KQ] = sum;
143154
}
@@ -146,21 +157,23 @@ static __global__ void flash_attn_vec_ext_f16(
146157

147158
#pragma unroll
148159
for (int j = 0; j < ncols; ++j) {
149-
kqmax_new[j] = warp_reduce_max(kqmax_new[j]);
160+
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
161+
162+
kqmax_new_j = warp_reduce_max(kqmax_new_j);
150163
if (threadIdx.x == 0) {
151-
kqmax_shared[j][threadIdx.y] = kqmax_new[j];
164+
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
152165
}
153166
}
154167

155168
__syncthreads();
156169

157170
#pragma unroll
158171
for (int j = 0; j < ncols; ++j) {
159-
kqmax_new[j] = kqmax_shared[j][threadIdx.x];
160-
kqmax_new[j] = warp_reduce_max(kqmax_new[j]);
172+
half kqmax_new_j = kqmax_shared[j][threadIdx.x];
173+
kqmax_new_j = warp_reduce_max(kqmax_new_j);
161174

162-
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new[j]);
163-
kqmax[j] = kqmax_new[j];
175+
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
176+
kqmax[j] = kqmax_new_j;
164177

165178
const half val = hexp(KQ[j*D + tid] - kqmax[j]);
166179
kqsum[j] = kqsum[j]*KQ_max_scale + val;

0 commit comments

Comments
 (0)