
Commit c14f38e

abetlen authored and slaren committed
llama: Add attention and final logit soft-capping, update scaling factor to Gemma2 (ggml-org#8197)
* Add attention and final logit softcapping.
* fix
* Add custom add_ functions
* Disable flash attention for Gemma2
* Update src/llama.cpp
  Co-authored-by: slaren <[email protected]>
* Add default value for attention and final logit softcap value
* Add custom kq scaling from Gemma2Attention
* Remove custom pre attention scaling and use computed value instead.

---------

Co-authored-by: slaren <[email protected]>
1 parent 7b2e884 commit c14f38e

File tree

4 files changed, +46 -3 lines changed

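A note on the technique before the per-file diffs: soft-capping squashes a tensor of logits smoothly into the open interval (-cap, +cap) instead of hard-clipping it. The changes below implement it as scale, tanh, scale, which is exactly softcap(x) = cap * tanh(x / cap). A minimal NumPy sketch of that transformation (illustrative only, not part of the commit; the caps of 50 and 30 are the defaults added in the llama.cpp hunks further down):

import numpy as np

def softcap(x: np.ndarray, cap: float) -> np.ndarray:
    # cap * tanh(x / cap): close to the identity for |x| << cap,
    # saturating smoothly toward +/- cap as |x| grows
    return cap * np.tanh(x / cap)

scores = np.array([-100.0, -10.0, 0.0, 10.0, 100.0])
print(softcap(scores, 50.0))  # attention-score cap
print(softcap(scores, 30.0))  # final-logit cap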

convert-hf-to-gguf.py

Lines changed: 6 additions & 0 deletions
@@ -2394,6 +2394,12 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_key_length(hparams["head_dim"])
         self.gguf_writer.add_value_length(hparams["head_dim"])
         self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_attn_logit_softcapping(
+            self.hparams["attn_logit_softcapping"]
+        )
+        self.gguf_writer.add_final_logit_softcapping(
+            self.hparams["final_logit_softcapping"]
+        )

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
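This converter change assumes the Hugging Face config for Gemma-2 exposes attn_logit_softcapping and final_logit_softcapping; published Gemma-2 configs appear to use 50.0 and 30.0, which also match the defaults added to llama_hparams below. A hedged sketch of where the values come from, with a hand-written snippet standing in for the real config.json:

import json

# Hypothetical excerpt of a Gemma-2 config.json; the real converter populates
# self.hparams from the checkpoint directory, these two keys included.
config = json.loads('{"attn_logit_softcapping": 50.0, "final_logit_softcapping": 30.0}')

# Mirrors the two writer calls added to set_gguf_parameters() above.
attn_cap  = config["attn_logit_softcapping"]   # -> "{arch}.attn_logit_softcapping"
final_cap = config["final_logit_softcapping"]  # -> "{arch}.final_logit_softcapping"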

gguf-py/gguf/constants.py

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,8 @@ class LLM:
         POOLING_TYPE            = "{arch}.pooling_type"
         LOGIT_SCALE             = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID  = "{arch}.decoder_start_token_id"
+        ATTN_LOGIT_SOFTCAPPING  = "{arch}.attn_logit_softcapping"
+        FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"

     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"

gguf-py/gguf/gguf_writer.py

Lines changed: 6 additions & 0 deletions
@@ -516,6 +516,12 @@ def add_clamp_kqv(self, value: float) -> None:
     def add_logit_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)

+    def add_attn_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
+    def add_final_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
     def add_expert_count(self, count: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
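Taken together, each new writer method stores a single float32 key under the architecture prefix defined in constants.py. A rough standalone usage sketch, assuming the gguf-py package from this tree is importable (the real driver is convert-hf-to-gguf.py above):

import gguf

writer = gguf.GGUFWriter("gemma2-metadata-only.gguf", arch="gemma2")
writer.add_attn_logit_softcapping(50.0)   # float32 key "gemma2.attn_logit_softcapping"
writer.add_final_logit_softcapping(30.0)  # float32 key "gemma2.final_logit_softcapping"
# remaining metadata, tensor data and the final write-to-file calls omitted here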

llama.cpp

Lines changed: 32 additions & 3 deletions
@@ -326,6 +326,8 @@ enum llm_kv {
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
+    LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_FINAL_LOGIT_SOFTCAPPING,

     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -416,6 +418,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_POOLING_TYPE,            "%s.pooling_type"            },
     { LLM_KV_LOGIT_SCALE,             "%s.logit_scale"             },
     { LLM_KV_DECODER_START_TOKEN_ID,  "%s.decoder_start_token_id"  },
+    { LLM_KV_ATTN_LOGIT_SOFTCAPPING,  "%s.attn_logit_softcapping"  },
+    { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },

     { LLM_KV_ATTENTION_HEAD_COUNT,    "%s.attention.head_count"    },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -2127,6 +2131,9 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;

+    float f_attn_logit_softcapping = 50.0f;
+    float f_final_logit_softcapping = 30.0f;
+
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
     float rope_freq_scale_train;
@@ -2143,8 +2150,9 @@ struct llama_hparams {
     float f_max_alibi_bias = 0.0f;
     float f_logit_scale = 0.0f;

-    bool causal_attn = true;
-    bool use_alibi = false;
+    bool causal_attn   = true;
+    bool use_alibi     = false;
+    bool attn_soft_cap = false;

     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
@@ -4822,6 +4830,9 @@ static void llm_load_hparams(
         case LLM_ARCH_GEMMA2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+                hparams.attn_soft_cap = true;

                 switch (hparams.n_layer) {
                     case 42: model.type = e_model::MODEL_9B; break;
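Worth noting: the trailing false makes both get_key lookups optional, so GGUF files converted before this change still load and simply fall back to the 50.0 / 30.0 defaults declared in llama_hparams, while attn_soft_cap is switched on for every Gemma-2 model. A rough Python model of that fallback (the dict stands in for parsed GGUF metadata, it is not a real llama.cpp API):

# Stand-in metadata for a Gemma-2 GGUF file; older conversions may simply
# lack the two soft-capping keys, in which case the defaults apply.
metadata = {"gemma2.attention.layer_norm_rms_epsilon": 1e-06}

attn_cap      = metadata.get("gemma2.attn_logit_softcapping", 50.0)   # default from llama_hparams
final_cap     = metadata.get("gemma2.final_logit_softcapping", 30.0)  # default from llama_hparams
attn_soft_cap = True                                                  # always on for LLM_ARCH_GEMMA2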
@@ -7737,6 +7748,12 @@ static struct ggml_tensor * llm_build_kqv(
         kq = ggml_scale(ctx, kq, 30);
     }

+    if (hparams.attn_soft_cap) {
+        kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+        kq = ggml_tanh(ctx, kq);
+        kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+    }
+
     kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
     cb(kq, "kq_soft_max_ext", il);
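For reference, a NumPy rendering of the capped pre-softmax path added above, ignoring the KQ mask and ALiBi bias and assuming the 1/sqrt(d) query scaling has already been folded into Q as the Gemma-2 graph below does (an illustrative sketch, not llama.cpp code):

import numpy as np

def capped_softmax(kq: np.ndarray, cap: float) -> np.ndarray:
    # scale -> tanh -> scale is exactly cap * tanh(kq / cap), elementwise
    kq = kq * (1.0 / cap)
    kq = np.tanh(kq)
    kq = kq * cap
    # numerically stable softmax over the last axis (stands in for ggml_soft_max_ext)
    kq = kq - kq.max(axis=-1, keepdims=True)
    e = np.exp(kq)
    return e / e.sum(axis=-1, keepdims=True)

probs = capped_softmax(np.random.randn(4, 8) * 100.0, cap=50.0)
assert np.allclose(probs.sum(axis=-1), 1.0)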

@@ -11197,7 +11214,7 @@ struct llm_build_context {
                     ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);

-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
                 cb(Qcur, "Qcur_scaled", il);

                 Kcur = ggml_rope_ext(
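The scaling change follows the commit message: instead of a hard-coded pre-attention scale, the graph now divides by sqrt(n_embd / n_head), mirroring the computed query scale of Gemma2Attention, which does not have to equal the head dimension n_embd_head_k. A quick arithmetic check with dimensions in the shape of the published Gemma-2 27B config (treat the concrete numbers as assumptions):

import math

n_embd, n_head, head_dim = 4608, 32, 128       # assumed Gemma-2 27B shapes

old_scale = 1.0 / math.sqrt(head_dim)          # 1/sqrt(n_embd_head_k), about 0.0884
new_scale = 1.0 / math.sqrt(n_embd / n_head)   # 1/sqrt(4608 / 32 = 144), about 0.0833
print(old_scale, new_scale)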
@@ -11264,6 +11281,12 @@

         // lm_head
         cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        // final logit soft-capping
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+
         cb(cur, "result_output", -1);

         ggml_build_forward_expand(gf, cur);
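With the default cap of 30 this keeps every output logit strictly inside (-30, 30) before it is exposed as result_output. Since tanh is monotonic the ordering of tokens, and therefore greedy argmax, is unchanged, although the softmax probabilities do shift. A tiny demonstration:

import numpy as np

logits = np.array([12.0, 45.0, -80.0, 3.0])
capped = 30.0 * np.tanh(logits / 30.0)         # final logit soft-capping, cap = 30
assert np.argmax(capped) == np.argmax(logits)  # ordering preserved
print(capped)                                  # roughly [11.4, 27.2, -29.7, 3.0]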
@@ -20022,6 +20045,12 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }

+    if (params.flash_attn && model->hparams.attn_soft_cap) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+
     if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
         LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
