@@ -302,6 +302,8 @@ enum llm_kv {
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
+ LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+ LLM_KV_FINAL_LOGIT_SOFTCAPPING,

LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -392,6 +394,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
+ { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
+ { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },

{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
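For context, the `%s` placeholder in these name templates is filled with the model architecture's name when the key is looked up, so for Gemma 2 the new entries should resolve to gemma2.attn_logit_softcapping and gemma2.final_logit_softcapping in the GGUF metadata. A minimal sketch of that substitution with plain snprintf (illustrative only, not the helper llama.cpp itself uses):

#include <cstdio>

int main() {
    // Illustrative only: expand the "%s.attn_logit_softcapping" template with the
    // architecture name, the way the KV name templates above are intended to be used.
    char key[128];
    std::snprintf(key, sizeof(key), "%s.attn_logit_softcapping", "gemma2");
    std::printf("%s\n", key); // gemma2.attn_logit_softcapping
    return 0;
}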
@@ -2099,6 +2103,9 @@ struct llama_hparams {
float f_norm_eps;
float f_norm_rms_eps;

+ float f_attn_logit_softcapping = 50.0f;
+ float f_final_logit_softcapping = 30.0f;
+
float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_scale_train;
@@ -2115,8 +2122,9 @@ struct llama_hparams {
float f_max_alibi_bias = 0.0f;
float f_logit_scale = 0.0f;

- bool causal_attn = true;
- bool use_alibi   = false;
+ bool causal_attn   = true;
+ bool use_alibi     = false;
+ bool attn_soft_cap = false;

enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
@@ -4702,6 +4710,9 @@ static void llm_load_hparams(
case LLM_ARCH_GEMMA2:
    {
        ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+       ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
+       ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+       hparams.attn_soft_cap = true;

        switch (hparams.n_layer) {
            case 42: model.type = e_model::MODEL_9B; break;
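Note that both new ml.get_key calls pass false for the trailing "required" flag, so the keys are optional: GGUF files converted before these metadata keys existed should still load and simply fall back to the 50.0f and 30.0f defaults declared in llama_hparams above.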
@@ -7579,6 +7590,12 @@ static struct ggml_tensor * llm_build_kqv(
    kq = ggml_scale(ctx, kq, 30);
}

+ if (hparams.attn_soft_cap) {
+     kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+     kq = ggml_tanh(ctx, kq);
+     kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+ }
+
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
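The block added above is a soft-cap: each raw attention score x is replaced by cap * tanh(x / cap), which is approximately x when |x| is much smaller than the cap and saturates smoothly at +/- cap for large |x|, keeping the logits bounded before the softmax. A self-contained sketch of the same math outside the ggml graph (hypothetical helper, values chosen for illustration):

#include <cmath>
#include <cstdio>
#include <initializer_list>

// Soft-cap a single logit: roughly the identity for |x| << cap, saturates at +/- cap.
static float soft_cap(float x, float cap) {
    return cap * std::tanh(x / cap);
}

int main() {
    const float cap = 50.0f; // matches the default f_attn_logit_softcapping above
    for (float x : {1.0f, 25.0f, 100.0f, 500.0f}) {
        std::printf("%7.1f -> %6.2f\n", x, soft_cap(x, cap));
    }
    return 0;
}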
@@ -11039,7 +11056,7 @@ struct llm_build_context {
        ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);

- Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+ Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
cb(Qcur, "Qcur_scaled", il);

Kcur = ggml_rope_ext(
@@ -11106,6 +11123,12 @@ struct llm_build_context {
// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
+
+ // final logit soft-capping
+ cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+ cur = ggml_tanh(ctx0, cur);
+ cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+
cb(cur, "result_output", -1);

ggml_build_forward_expand(gf, cur);
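As a worked example of the cap added above: with the default f_final_logit_softcapping of 30.0f, a raw output logit of 100 becomes 30 * tanh(100 / 30) ≈ 29.9, so every logit handed to sampling lies strictly inside (-30, 30). Because tanh is monotonic, the relative ordering of the logits (and hence the argmax) is unchanged; only their spread is compressed.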
@@ -17379,6 +17402,12 @@ struct llama_context * llama_new_context_with_model(
    params.flash_attn = false;
}

+ if (params.flash_attn && model->hparams.attn_soft_cap) {
+     LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
+     params.flash_attn = false;
+ }
+
+
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
    LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
    params.flash_attn = false;
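A likely reason for the new guard, not stated in the patch itself: the fused flash-attention path computes the scaled QK^T scores and the softmax inside a single kernel, leaving no point at which the tanh cap could be applied in between, so soft-capped models fall back to the regular llm_build_kqv path shown earlier.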