@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template <int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template <int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
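Note on the early-out above: `use_logit_softcap` and `D` are template parameters, so the condition is a compile-time constant and the compiler can discard the whole kernel body for variants the host never launches (soft-capping with head sizes other than 128/256), which keeps compile time and binary size down. Below is a minimal standalone sketch of that pruning pattern, with hypothetical names (`toy_kernel`, `TOY_NO_DEVICE_CODE`) standing in for the real kernel and ggml's `NO_DEVICE_CODE` macro; it is an illustration of the idiom, not code from this PR.

```cpp
#include <cstdio>

// Stand-in for ggml's NO_DEVICE_CODE (which emits a diagnostic/trap).
#define TOY_NO_DEVICE_CODE printf("unsupported kernel variant\n")

template <int D, bool use_logit_softcap>
__global__ void toy_kernel(float * dst) {
    // D and use_logit_softcap are compile-time constants, so for excluded
    // combinations the compiler reduces the kernel to this early-out.
    if (use_logit_softcap && !(D == 128 || D == 256)) {
        TOY_NO_DEVICE_CODE;
        return;
    }
    dst[threadIdx.x] = use_logit_softcap ? 1.0f : 0.0f; // real work would go here
}

int main() {
    float * dst;
    cudaMalloc(&dst, 32*sizeof(float));
    toy_kernel< 64, false><<<1, 32>>>(dst); // compiled normally
    toy_kernel< 64, true ><<<1, 32>>>(dst); // body pruned to the early-out
    cudaDeviceSynchronize();
    cudaFree(dst);
    return 0;
}
```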
@@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                 const int j_KQ = j_KQ_0 + threadIdx.y;
 
-                half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                half sum;
+                if (use_logit_softcap) {
+                    const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                    sum = logit_softcap * tanhf(tmp.x + tmp.y);
+                } else {
+                    sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                }
                 sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
                 kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
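For context: logit soft-capping in the Gemma-2 style maps a raw attention score x to cap * tanh(x / cap), so scores saturate at +/-cap instead of growing without bound. In the hunk above only the multiplication `logit_softcap * tanhf(...)` appears; the 1/cap factor is presumably folded into the KQ scale applied upstream, and the sum is promoted to float because the tanh is evaluated in single precision. The sketch below is a hypothetical host-side reference of that formulation (`softcap_ref` is not a ggml function), just to show the saturating behaviour.

```cpp
#include <cmath>
#include <cstdio>

// Reference soft-cap: cap * tanh(x / cap). For |x| << cap it is ~x,
// for large |x| it saturates at +/-cap.
static float softcap_ref(float x, float cap) {
    return cap * tanhf(x / cap);
}

int main() {
    const float cap = 50.0f; // Gemma-2 uses 50.0 for attention logits
    for (float x : {1.0f, 25.0f, 100.0f, 400.0f}) {
        printf("x = %6.1f  capped = %8.4f\n", x, softcap_ref(x, cap));
    }
    return 0;
}
```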
@@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64: {
             constexpr int D      = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D      = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
@@ -296,24 +309,45 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
 
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
 }
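The host wrapper above reads the runtime `logit_softcap` value from `op_params` and turns it into the compile-time `use_logit_softcap` flag by branching once and passing a `constexpr bool` into the launcher template; a cap of 0.0f means soft-capping is disabled. The sketch below shows that runtime-to-compile-time dispatch idiom in isolation, under assumed names (`dispatch_softcap`, `run_variant` are hypothetical and not part of ggml).

```cpp
#include <cstdio>

template <bool use_logit_softcap>
static void run_variant(float logit_softcap) {
    // A real implementation would select the matching kernel instantiation here.
    printf("use_logit_softcap=%d, logit_softcap=%f\n", (int) use_logit_softcap, logit_softcap);
}

static void dispatch_softcap(float logit_softcap) {
    // Branch once on the runtime value; downstream code then sees a
    // compile-time constant and dead branches can be eliminated.
    if (logit_softcap == 0.0f) {
        constexpr bool use_logit_softcap = false;
        run_variant<use_logit_softcap>(logit_softcap);
    } else {
        constexpr bool use_logit_softcap = true;
        run_variant<use_logit_softcap>(logit_softcap);
    }
}

int main() {
    dispatch_softcap(0.0f);  // soft-capping disabled
    dispatch_softcap(50.0f); // soft-capping enabled
    return 0;
}
```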