@@ -106,10 +106,15 @@ static __global__ void flash_attn_vec_ext_f16(
     const int k_start = parallel_blocks == 1 ? 0 : ip*D;
     for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
-        half kqmax_new[ncols];
+
+        // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
+        // see https://github.com/ggerganov/llama.cpp/pull/7061 .
+        // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
+        half kqmax_new = kqmax[0];
+        half kqmax_new_arr[ncols];
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            kqmax_new[j] = kqmax[j];
+            kqmax_new_arr[j] = kqmax[j];
         }
 
 #pragma unroll
@@ -137,7 +142,13 @@ static __global__ void flash_attn_vec_ext_f16(
                 sum2[j] = warp_reduce_sum(sum2[j]);
                 half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
                 sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
-                kqmax_new[j] = ggml_cuda_hmax(kqmax_new[j], sum);
+
+                if (ncols == 1) {
+                    kqmax_new        = ggml_cuda_hmax(kqmax_new, sum);
+                } else {
+                    kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
+                }
+
                 if (threadIdx.x == 0) {
                     KQ[j*D + i_KQ] = sum;
                 }
@@ -146,21 +157,23 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            kqmax_new[j] = warp_reduce_max(kqmax_new[j]);
+            half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
+
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
             if (threadIdx.x == 0) {
-                kqmax_shared[j][threadIdx.y] = kqmax_new[j];
+                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
             }
         }
 
         __syncthreads();
 
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            kqmax_new[j] = kqmax_shared[j][threadIdx.x];
-            kqmax_new[j] = warp_reduce_max(kqmax_new[j]);
+            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
 
-            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new[j]);
-            kqmax[j] = kqmax_new[j];
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
 
             const half val = hexp(KQ[j*D + tid] - kqmax[j]);
             kqsum[j] = kqsum[j]*KQ_max_scale + val;
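The ncols == 1 special case in the second hunk works because ncols is a compile-time template parameter of the kernel (it must be, since kqmax_new_arr[ncols] is a fixed-size array), so the branch is resolved during compilation and whichever of kqmax_new / kqmax_new_arr goes unused is eliminated. A minimal standalone sketch of the same trick follows; the function name is hypothetical, and __hmax from cuda_fp16.h is assumed to be available on the target architecture (llama.cpp itself goes through its own ggml_cuda_hmax wrapper):

#include <cuda_fp16.h>

// ncols is a compile-time template parameter, so one branch below is dead
// code and the variable it references can be optimized out entirely.
template <int ncols>
__device__ void running_max_update(half & max_scalar, half * max_arr, const int j, const half candidate) {
    if (ncols == 1) {
        max_scalar = __hmax(max_scalar, candidate); // scalar path: avoids a size-1 array
    } else {
        max_arr[j] = __hmax(max_arr[j], candidate); // array path: one running max per column
    }
}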
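The second loop in the last hunk is the standard online-softmax update: when a new tile raises the running maximum kqmax[j] to kqmax_new_j, the previously accumulated denominator kqsum[j] is rescaled by KQ_max_scale = exp(old_max - new_max) before the new term is added, so every term stays scaled by exp(-kqmax[j]) and nothing overflows. A host-side float sketch of the same recurrence, with made-up score values for illustration:

#include <cmath>
#include <cstdio>

int main() {
    const float scores[] = {1.5f, -0.3f, 4.2f, 2.0f, 3.3f}; // stand-ins for KQ values
    float m = -INFINITY; // running maximum, plays the role of kqmax[j]
    float s = 0.0f;      // running denominator, plays the role of kqsum[j]
    for (const float x : scores) {
        const float m_new = fmaxf(m, x);
        // KQ_max_scale = exp(m - m_new) rescales everything accumulated so far:
        s = s*expf(m - m_new) + expf(x - m_new);
        m = m_new;
    }
    printf("softmax denominator = %f\n", s); // equals sum_i exp(scores[i] - max)
    return 0;
}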