Commit a4c7cf7

ggml : ggml_flash_attn_ext() support ALiBi (CPU)

1 parent d0592d4
File tree: 3 files changed, +23 −9 lines changed

ggml.c

Lines changed: 20 additions & 7 deletions
@@ -6436,7 +6436,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor  * k,
         struct ggml_tensor  * v,
         struct ggml_tensor  * mask,
-        float                 scale) {
+        float                 scale,
+        float                 max_bias) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
     if (mask) {
@@ -6458,7 +6459,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

-    float params[] = { scale };
+    float params[] = { scale, max_bias };
     ggml_set_op_params(result, params, sizeof(params));

     result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -6478,7 +6479,7 @@ void ggml_flash_attn_ext_set_prec(

     const int32_t prec_i32 = (int32_t) prec;

-    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
 }

 // ggml_flash_ff
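
Note on the op_params layout: scale stays at float slot 0 and max_bias now occupies slot 1, so the precision flag written by ggml_flash_attn_ext_set_prec has to move from index 1 to index 2 — the dispatch switch in ggml_compute_forward_flash_attn_ext (further below) changes in lockstep.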
@@ -15524,8 +15525,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);

-    float scale = 1.0f;
-    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+    const uint32_t n_head_kv = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -15534,6 +15544,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);

+        const int h = iq2; // head
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
         float S = 0.0f;
         float M = -INFINITY;

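For readers unfamiliar with ALiBi (Press et al., https://arxiv.org/abs/2108.12409): each head h gets a fixed slope drawn from a geometric schedule derived from max_bias. Below is a minimal standalone sketch of the schedule as implemented in the two hunks above — alibi_slope is a hypothetical helper for illustration, not a ggml function:

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    // Hypothetical helper mirroring the slope schedule above: the first
    // n_head_log2 heads use base m0 = 2^(-max_bias/n_head_log2), the
    // remaining heads interpolate with base m1 = 2^(-max_bias/2/n_head_log2).
    static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
        if (max_bias <= 0.0f) {
            return 1.0f; // ALiBi disabled: mask values pass through unscaled
        }

        const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        return h < n_head_log2 ? powf(m0, h + 1)
                               : powf(m1, 2*(h - n_head_log2) + 1);
    }

    int main(void) {
        // with 8 heads and max_bias = 8.0f this prints 1/2, 1/4, ..., 1/256
        for (uint32_t h = 0; h < 8; ++h) {
            printf("head %u: slope = %g\n", h, alibi_slope(h, 8, 8.0f));
        }
        return 0;
    }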
@@ -15557,7 +15570,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
         for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
                 continue;
             }
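
With the slope folded into the mask value, the per-element score for head h effectively becomes

    s = scale*dot(q, k[ic]) + slope*mask[ic]

so a mask that stores negative key/query distances (as llama.cpp builds for ALiBi models) turns into the ALiBi linear bias, a plain causal mask of zeros is unaffected since it is scaled by a finite slope, and -INFINITY entries remain -INFINITY and still short-circuit the loop.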
@@ -15628,7 +15641,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[1]) {
+    switch (dst->op_params[2]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {

ggml.h

Lines changed: 2 additions & 1 deletion
@@ -1731,7 +1731,8 @@ extern "C" {
             struct ggml_tensor  * k,
             struct ggml_tensor  * v,
             struct ggml_tensor  * mask,
-            float                 scale);
+            float                 scale,
+            float                 max_bias);

     GGML_API void ggml_flash_attn_ext_set_prec(
             struct ggml_tensor * a,
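
A hedged sketch of calling the extended API — shapes follow the convention visible in ggml_flash_attn_ext above (ne = {head_dim, seq, heads, batch}); the mask padding/type constraints are elided here and should be checked against the asserts in ggml.c:

    // Sketch only: assumes an existing ggml context `ctx`; the CPU path
    // expects an F16 mask, and all shapes below are illustrative guesses.
    const int64_t d = 128, n_tokens = 32, n_kv = 256, n_head = 8;

    struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d, n_tokens, n_head, 1);
    struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, d, n_kv,     n_head, 1);
    struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, d, n_kv,     n_head, 1);
    struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, n_tokens);

    const float scale    = 1.0f/sqrtf((float) d); // usual 1/sqrt(head_dim)
    const float max_bias = 8.0f;                  // 0.0f for non-ALiBi models

    struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias);
    ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); // optional, stored at op_params[2]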

llama.cpp

Lines changed: 1 addition & 1 deletion
@@ -6537,7 +6537,7 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);

-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);

         if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
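
Because llm_build_kqv is shared across architectures, passing hparams.f_max_alibi_bias unconditionally is safe: models without ALiBi keep it at 0.0f, which the CPU kernel above maps to slope = 1.0f, leaving the previous behavior unchanged.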
