
Commit e11bd85

CPU/CUDA: Gemma 2 FlashAttention support (ggml-org#8542)
* CPU/CUDA: Gemma 2 FlashAttention support
* apply logit_softcap to scale in kernel
* disable logit softcapping tests on Metal
* remove metal check
1 parent 8f824ff commit e11bd85

12 files changed (+319, -79 lines)
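
A note on the math for readers of the diffs below (a reference sketch, not part of the commit): Gemma 2 soft-caps attention logits as cap * tanh(logit / cap), which bounds every logit to (-cap, cap) while staying nearly linear for small values; a cap of 0.0f means soft-capping is disabled. The CUDA changes fold the 1/cap factor into the existing scale on the host, so the kernels only add a tanhf and a multiply.

    #include <math.h>

    // Reference formulation only; `x` is the scaled attention logit and `cap`
    // is the model's soft-capping constant (names here are illustrative).
    static float softcap(float x, float cap) {
        return cap * tanhf(x / cap);
    }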

ggml/include/ggml.h

Lines changed: 2 additions & 1 deletion
@@ -1760,7 +1760,8 @@ extern "C" {
             struct ggml_tensor * v,
             struct ggml_tensor * mask,
             float scale,
-            float max_bias);
+            float max_bias,
+            float logit_softcap);
 
     GGML_API void ggml_flash_attn_ext_set_prec(
             struct ggml_tensor * a,
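
A minimal caller-side sketch of the extended signature (not from this commit; q, k, v, kq_mask and n_embd_head are placeholders assumed to exist elsewhere, and the constants are illustrative). Passing 0.0f for logit_softcap keeps the previous behaviour:

    #include <math.h>
    #include "ggml.h"

    // Hypothetical helper: q/k/v/kq_mask are assumed to be built elsewhere in the graph.
    static struct ggml_tensor * build_fattn(struct ggml_context * ctx,
            struct ggml_tensor * q, struct ggml_tensor * k, struct ggml_tensor * v,
            struct ggml_tensor * kq_mask, int n_embd_head) {
        const float scale         = 1.0f/sqrtf((float) n_embd_head);
        const float max_bias      = 0.0f;  // no ALiBi
        const float logit_softcap = 50.0f; // illustrative Gemma 2 style cap; 0.0f disables soft-capping

        return ggml_flash_attn_ext(ctx, q, k, v, kq_mask, scale, max_bias, logit_softcap);
    }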

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 12 additions & 5 deletions
@@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -657,11 +658,17 @@ void launch_fattn(
     const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
     const int shmem = 0;
 
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
 
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&scale,         (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0.0f) {
+        scale /= logit_softcap;
+    }
 
     const uint32_t n_head = Q->ne[2];
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@@ -675,7 +682,7 @@
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
         (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-        scale, max_bias, m0, m1, n_head_log2,
+        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
         mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
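
Why the host divides scale by logit_softcap: the kernels apply logit_softcap * tanhf(...) to the already-scaled dot product, so with the pre-divided scale the device effectively computes cap * tanh((scale * q.k) / cap), which is exactly the reference soft-cap of the scaled logit. A standalone check of that identity (illustrative values, not part of the commit):

    #include <assert.h>
    #include <math.h>

    int main(void) {
        const float scale = 0.125f, cap = 50.0f, qk = 7.3f;    // arbitrary test values
        const float reference = cap * tanhf(scale*qk / cap);   // softcap(scale * q.k)
        const float folded    = cap * tanhf((scale/cap) * qk); // what launch_fattn + kernel compute
        assert(fabsf(reference - folded) < 1e-5f);
        return 0;
    }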

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 43 additions & 9 deletions
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
         for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
             const int j_KQ = j_KQ_0 + threadIdx.y;
 
-            half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+            half sum;
+            if (use_logit_softcap) {
+                const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                sum = logit_softcap * tanhf(tmp.x + tmp.y);
+            } else {
+                sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+            }
             sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
             kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
@@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
             constexpr int D      = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D      = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
@@ -296,24 +309,45 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
 
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
 
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
 }
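
The repeated if/else around each launch is the usual way to turn a runtime value (logit_softcap read from op_params) into the compile-time template parameter use_logit_softcap, so the branch inside the kernel is resolved at compile time and the no-softcap path pays nothing. A condensed, generic sketch of the idiom (not the commit's exact code; names are illustrative):

    #include <stdio.h>

    // Each instantiation compiles its own kernel; `use_logit_softcap` is a constant inside.
    template <bool use_logit_softcap>
    static void run_tile_kernel(float logit_softcap) {
        if (use_logit_softcap) {
            printf("softcap variant, cap = %g\n", logit_softcap);
        } else {
            printf("plain variant\n");
        }
    }

    // Runtime value -> compile-time constant at the launch site.
    static void dispatch(float logit_softcap) {
        if (logit_softcap == 0.0f) {
            run_tile_kernel<false>(logit_softcap);
        } else {
            run_tile_kernel<true>(logit_softcap);
        }
    }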

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 40 additions & 7 deletions
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32(
         for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
             const int j_KQ = j_KQ_0 + threadIdx.y;
 
+            if (use_logit_softcap) {
+                sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+            }
+
             sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
 
             kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
@@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32(
     }
 }
 
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
 void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
             constexpr int D      = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D      = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
@@ -290,23 +301,45 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
 }
 
 void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
 
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
        return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
 }
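
For reference, the reads above imply the following op_params layout for the flash-attention op after this commit: scale, max_bias and logit_softcap are bit-copied floats in slots 0..2, and the precision value moves to slot 3 (it was slot 2 before, hence the KQV->op_params[3] change in the f16 path). A hypothetical helper mirroring those reads, assuming the ggml headers:

    #include <string.h>
    #include "ggml.h"

    // dst is the flash-attention node; op_params is an int32 array, so the
    // float parameters are recovered with memcpy rather than a cast-and-index.
    static void read_fattn_ext_params(const struct ggml_tensor * dst,
            float * scale, float * max_bias, float * logit_softcap, int32_t * precision) {
        memcpy(scale,         (const float *) dst->op_params + 0, sizeof(float));
        memcpy(max_bias,      (const float *) dst->op_params + 1, sizeof(float));
        memcpy(logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
        *precision = dst->op_params[3];
    }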
