
llama : fix Gemma3 SWA KV cache shift #12373


Merged: 2 commits, Mar 13, 2025
17 changes: 14 additions & 3 deletions src/llama-context.cpp
@@ -442,10 +442,10 @@ ggml_tensor * llama_context::build_rope_shift(
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale,
ggml_backend_buffer * bbuf) const {
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
const auto & freq_base = cparams.rope_freq_base;
const auto & freq_scale = cparams.rope_freq_scale;

const auto & yarn_ext_factor = cparams.yarn_ext_factor;
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
@@ -537,6 +537,17 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

float freq_base_l = cparams.rope_freq_base;
float freq_scale_l = cparams.rope_freq_scale;

// TODO: improve
if (model.arch == LLM_ARCH_GEMMA3) {
const bool is_sliding = hparams.is_sliding(il);

freq_base_l = is_sliding ? 10000.0f : cparams.rope_freq_base;
freq_scale_l = is_sliding ? 1.0f : cparams.rope_freq_scale;
}

Comment on lines +540 to +550

Member (author):
Not sure how to avoid this special-casing here. It does not look great.

Collaborator:

I think we can extend the llama_layer to hold this info in the near future.

Member (author):

For now, I've pushed the following version, which should be a bit cleaner: #12374

Will see if there is a better way to do it with the upcoming model implementation refactoring.
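
As a purely illustrative sketch of the llama_layer idea above (not code from this PR; the struct and helper names are invented for the example), per-layer RoPE parameters could be resolved without any architecture check in build_kv_self_shift():

    // Hypothetical sketch only -- not part of this PR.
    // If each layer carried its own RoPE parameters, the LLM_ARCH_GEMMA3
    // special case above could be replaced by a generic per-layer lookup.
    struct llama_layer_rope {        // invented name
        float freq_base  = 0.0f;     // 0 => fall back to cparams.rope_freq_base
        float freq_scale = 0.0f;     // 0 => fall back to cparams.rope_freq_scale
    };

    static float resolve_freq_base(const llama_layer_rope & l, float cparams_freq_base) {
        return l.freq_base > 0.0f ? l.freq_base : cparams_freq_base;
    }

    static float resolve_freq_scale(const llama_layer_rope & l, float cparams_freq_scale) {
        return l.freq_scale > 0.0f ? l.freq_scale : cparams_freq_scale;
    }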

ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);

ggml_tensor * k =
@@ -546,7 +557,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
0);

ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, kv_self->k_l[il]->buffer);
ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);

ggml_build_forward_expand(gf, cur);
}
2 changes: 2 additions & 0 deletions src/llama-context.h
@@ -168,6 +168,8 @@ struct llama_context {
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale,
ggml_backend_buffer * bbuf) const;

llm_graph_result_ptr build_kv_self_shift(
29 changes: 1 addition & 28 deletions src/llama-graph.cpp
@@ -1403,34 +1403,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
}

// TODO: improve
bool is_sliding = false;

switch (arch) {
case LLM_ARCH_COHERE2:
{
const int32_t sliding_window_pattern = 4;
is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
} break;
case LLM_ARCH_GEMMA2:
{
const int32_t sliding_window_pattern = 2;
is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
} break;
case LLM_ARCH_GEMMA3:
{
const int32_t sliding_window_pattern = 6;
is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
} break;
case LLM_ARCH_PHI3:
{
is_sliding = hparams.n_swa > 0;
} break;
default:
{
is_sliding = false;
}
};
const bool is_sliding = hparams.is_sliding(il);

const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();

8 changes: 8 additions & 0 deletions src/llama-hparams.cpp
@@ -69,3 +69,11 @@ uint32_t llama_hparams::n_embd_v_s() const {
// corresponds to Mamba's ssm_states size
return ssm_d_state * ssm_d_inner;
}

bool llama_hparams::is_sliding(uint32_t il) const {
if (il < n_layer) {
return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
}

GGML_ABORT("fatal error");
}
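
For reference, a minimal standalone demo of how the new is_sliding() check plays out for Gemma3's interleaved pattern (illustrative re-implementation; the n_swa value below is just a placeholder, the real one is read from the model metadata in load_hparams):

    // Demo only: mirrors the pattern check added to llama_hparams::is_sliding().
    #include <cstdint>
    #include <cstdio>

    static bool is_sliding_demo(uint32_t il, uint32_t n_swa, uint32_t n_swa_pattern) {
        return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
    }

    int main() {
        const uint32_t n_swa         = 1024; // placeholder sliding window size
        const uint32_t n_swa_pattern = 6;    // Gemma3: 5 SWA layers followed by 1 global layer

        for (uint32_t il = 0; il < 12; ++il) {
            std::printf("layer %2u: %s\n", il, is_sliding_demo(il, n_swa, n_swa_pattern) ? "sliding" : "global");
        }
        // prints "sliding" for layers 0-4 and 6-10, "global" for layers 5 and 11
        return 0;
    }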
3 changes: 3 additions & 0 deletions src/llama-hparams.h
@@ -36,6 +36,7 @@ struct llama_hparams {
uint32_t n_layer;
uint32_t n_rot;
uint32_t n_swa = 0; // sliding window attention (SWA)
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0;
@@ -133,6 +134,8 @@ struct llama_hparams {

// dimension of the recurrent state embeddings
uint32_t n_embd_v_s() const;

bool is_sliding(uint32_t il) const;
};

static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
21 changes: 9 additions & 12 deletions src/llama-model.cpp
@@ -858,11 +858,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_GEMMA2:
{
hparams.n_swa = 4096; // default value of gemma 2
hparams.n_swa_pattern = 2;
hparams.attn_soft_cap = true;

ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
hparams.attn_soft_cap = true;

switch (hparams.n_layer) {
case 26: type = LLM_TYPE_2B; break;
@@ -873,6 +875,8 @@
} break;
case LLM_ARCH_GEMMA3:
{
hparams.n_swa_pattern = 6;

ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

@@ -952,6 +956,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_COHERE2:
{
hparams.n_swa_pattern = 4;

ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -7374,12 +7380,8 @@ struct llm_build_gemma3 : public llm_graph_context {
// TODO: is causal == true correct? might need some changes
auto * inp_attn = build_attn_inp_kv_unified(true, true);

// "5-to-1 interleaved attention"
// 5 layers of local attention followed by 1 layer of global attention
static const int sliding_window_pattern = 6;

for (int il = 0; il < n_layer; ++il) {
const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
const bool is_sliding = hparams.is_sliding(il);

const float freq_base_l = is_sliding ? 10000.0f : freq_base;
const float freq_scale_l = is_sliding ? 1.0f : freq_scale;
@@ -7970,13 +7972,8 @@ struct llm_build_cohere2 : public llm_graph_context {

auto * inp_attn = build_attn_inp_kv_unified(true, true);

// sliding window switch pattern
const int32_t sliding_window_pattern = 4;

for (int il = 0; il < n_layer; ++il) {
// three layers sliding window attention (window size 4096) and ROPE
// fourth layer uses global attention without positional embeddings
const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
const bool is_sliding = hparams.is_sliding(il);

// norm
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il);