@@ -6436,7 +6436,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor  * k,
         struct ggml_tensor  * v,
         struct ggml_tensor  * mask,
-        float                 scale) {
+        float                 scale,
+        float                 max_bias) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
     if (mask) {
@@ -6458,7 +6459,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    float params[] = { scale };
+    float params[] = { scale, max_bias };
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_FLASH_ATTN_EXT;
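
With the new signature, a caller passes both the softmax scale and the ALiBi max_bias when building the node. A minimal sketch of such a call site follows; the context size, tensor shapes and the max_bias value are illustrative assumptions, not taken from this diff, and a real caller must still satisfy whatever mask-padding asserts its ggml revision enforces.

    #include <math.h>
    #include "ggml.h"

    // illustrative only: build a single flash-attention node with ALiBi enabled
    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 64*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        const int64_t D = 128, N = 32, H = 8; // head dim, tokens, heads (assumed)

        struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, N, H, 1);
        struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, D, N, H, 1);
        struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, D, N, H, 1);
        struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, N, N);

        const float scale    = 1.0f/sqrtf((float) D);
        const float max_bias = 8.0f; // 0.0f keeps the previous behaviour (no ALiBi)

        struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias);
        ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);

        (void) out;
        ggml_free(ctx);
        return 0;
    }
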
@@ -6478,7 +6479,7 @@ void ggml_flash_attn_ext_set_prec(
 
     const int32_t prec_i32 = (int32_t) prec;
 
-    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
 }
 
 // ggml_flash_ff
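
The index passed to ggml_set_op_params_i32 moves from 1 to 2 because max_bias now occupies the second float slot. Reading the packed parameters back therefore looks like the following sketch (a hypothetical helper, not part of this change):

    #include <string.h>
    #include "ggml.h"

    // hypothetical read-back of the GGML_OP_FLASH_ATTN_EXT op_params layout
    // implied by this diff: float slot 0 = scale, float slot 1 = max_bias,
    // int32 slot 2 = precision
    static void flash_attn_ext_get_params(const struct ggml_tensor * t,
            float * scale, float * max_bias, int32_t * prec) {
        memcpy(scale,    (const float   *) t->op_params + 0, sizeof(float));
        memcpy(max_bias, (const float   *) t->op_params + 1, sizeof(float));
        memcpy(prec,     (const int32_t *) t->op_params + 2, sizeof(int32_t));
    }
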
@@ -15524,8 +15525,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    float scale = 1.0f;
-    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+    const uint32_t n_head_kv   = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -15534,6 +15544,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
+        const int h = iq2; // head
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
         float S = 0.0f;
         float M = -INFINITY;
 
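
Together, the two hunks above implement the standard ALiBi head slopes: the first n_head_log2 heads get slopes m0^(h+1), the remaining heads the interleaved sequence m1^(2(h - n_head_log2)+1), and with max_bias == 0 every slope collapses to 1 so the kernel behaves as before. A small standalone sketch (head count and max_bias are arbitrary example values) that reproduces the slope table:

    #include <math.h>
    #include <stdio.h>
    #include <stdint.h>

    // standalone sketch of the per-head ALiBi slopes computed in the kernel above;
    // n_head and max_bias are arbitrary example values
    int main(void) {
        const uint32_t n_head   = 12;
        const float    max_bias = 8.0f;

        const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        for (uint32_t h = 0; h < n_head; ++h) {
            const float slope = (max_bias > 0.0f)
                ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1))
                : 1.0f;
            printf("head %2u: slope = %f\n", (unsigned) h, slope);
        }
        return 0;
    }
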
@@ -15557,7 +15570,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
         for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
                 continue;
             }
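
Multiplying the FP16 mask by the per-head slope is what turns the additive mask into the ALiBi bias: for a plain causal mask the entries stay 0 or -INFINITY (the continue path), while a mask that encodes relative positions is scaled per head before being added to the scaled KQ value inside the online-softmax loop. A scalar sketch of the resulting logit, assuming (as in the surrounding kernel at this point) that the scaled dot product and the mask term are simply summed:

    #include <stdint.h>
    #include "ggml.h"

    // scalar sketch of the logit fed to the online softmax; s is the raw q·k dot
    // product for key column ic, mp the optional FP16 mask row. Uses the public
    // ggml_fp16_to_fp32() instead of the kernel's internal GGML_FP16_TO_FP32 macro.
    static float flash_attn_logit(float s, float scale, float slope,
            const ggml_fp16_t * mp, int64_t ic) {
        const float mv = mp ? slope*ggml_fp16_to_fp32(mp[ic]) : 0.0f;
        return s*scale + mv;
    }
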
@@ -15628,7 +15641,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[1]) {
+    switch (dst->op_params[2]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {