@@ -317,6 +317,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_Q_LORA_RANK,
     LLM_KV_ATTENTION_KV_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
+    LLM_KV_ATTENTION_SLIDING_WINDOW,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
@@ -409,6 +410,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_Q_LORA_RANK,            "%s.attention.q_lora_rank"            },
     { LLM_KV_ATTENTION_KV_LORA_RANK,           "%s.attention.kv_lora_rank"           },
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+    { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,             "%s.rope.dimension_count"             },
     { LLM_KV_ROPE_FREQ_BASE,                   "%s.rope.freq_base"                   },
@@ -2085,6 +2087,7 @@ struct llama_hparams {
     uint32_t n_head_kv;
     uint32_t n_layer;
     uint32_t n_rot;
+    uint32_t n_swa = 0; // sliding window attention (SWA)
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_ff;
@@ -2139,6 +2142,7 @@ struct llama_hparams {
         if (this->n_head_kv     != other.n_head_kv)     return true;
         if (this->n_layer       != other.n_layer)       return true;
         if (this->n_rot         != other.n_rot)         return true;
+        if (this->n_swa         != other.n_swa)         return true;
         if (this->n_embd_head_k != other.n_embd_head_k) return true;
         if (this->n_embd_head_v != other.n_embd_head_v) return true;
         if (this->n_ff          != other.n_ff)          return true;
@@ -2649,17 +2653,18 @@ struct llama_context {
     void * abort_callback_data = nullptr;
 
     // input tensors
-    struct ggml_tensor * inp_tokens;    // I32 [n_batch]
-    struct ggml_tensor * inp_embd;      // F32 [n_embd, n_batch]
-    struct ggml_tensor * inp_pos;       // I32 [n_batch]
-    struct ggml_tensor * inp_out_ids;   // I32 [n_outputs]
-    struct ggml_tensor * inp_KQ_mask;   // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_K_shift;   // I32 [kv_size]
-    struct ggml_tensor * inp_mean;      // F32 [n_batch, n_batch]
-    struct ggml_tensor * inp_cls;       // I32 [n_batch]
-    struct ggml_tensor * inp_s_copy;    // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask;    // F32 [1, n_kv]
-    struct ggml_tensor * inp_s_seq;     // I32 [n_kv, n_batch]
+    struct ggml_tensor * inp_tokens;      // I32 [n_batch]
+    struct ggml_tensor * inp_embd;        // F32 [n_embd, n_batch]
+    struct ggml_tensor * inp_pos;         // I32 [n_batch]
+    struct ggml_tensor * inp_out_ids;     // I32 [n_outputs]
+    struct ggml_tensor * inp_KQ_mask;     // F32 [kv_size, n_batch]
+    struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
+    struct ggml_tensor * inp_K_shift;     // I32 [kv_size]
+    struct ggml_tensor * inp_mean;        // F32 [n_batch, n_batch]
+    struct ggml_tensor * inp_cls;         // I32 [n_batch]
+    struct ggml_tensor * inp_s_copy;      // I32 [kv_size]
+    struct ggml_tensor * inp_s_mask;      // F32 [1, n_kv]
+    struct ggml_tensor * inp_s_seq;       // I32 [n_kv, n_batch]
 
     // control vectors
     struct llama_control_vector cvec;
@@ -4709,6 +4714,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_GEMMA2:
             {
+                hparams.n_swa = 4096; // default value of gemma 2
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,      hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,     hparams.f_final_logit_softcapping, false);
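Aside: the two added lines set a Gemma 2 default of 4096 and then read the optional %s.attention.sliding_window key; the final false argument marks the key as not required, so a model file without it keeps the default. A minimal, self-contained sketch of that fallback pattern, under the assumption that this is how the optional-key path behaves (the metadata map and helper below are hypothetical stand-ins, not llama.cpp API):

#include <cstdint>
#include <cstdio>
#include <map>
#include <string>

// hypothetical stand-in for the loader's key/value metadata
using kv_map = std::map<std::string, uint32_t>;

// optional-key pattern: leave dst (the preset default) untouched when the key is missing
static bool get_key_optional(const kv_map & meta, const std::string & key, uint32_t & dst) {
    const auto it = meta.find(key);
    if (it == meta.end()) {
        return false;
    }
    dst = it->second;
    return true;
}

int main() {
    const kv_map meta = {}; // pretend the model file carries no sliding_window key

    uint32_t n_swa = 4096; // default value of gemma 2
    get_key_optional(meta, "gemma2.attention.sliding_window", n_swa);

    printf("n_swa = %u\n", n_swa); // prints 4096, the default
    return 0;
}
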
@@ -5419,6 +5426,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: n_head_kv        = %u\n",     __func__, hparams.n_head_kv);
     LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
     LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
+    LLAMA_LOG_INFO("%s: n_swa            = %u\n",     __func__, hparams.n_swa);
     LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
     LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
     LLAMA_LOG_INFO("%s: n_gqa            = %u\n",     __func__, hparams.n_gqa());
@@ -7775,17 +7783,18 @@ struct llm_build_context {
 
         ctx0 = ggml_init(params);
 
-        lctx.inp_tokens  = nullptr;
-        lctx.inp_embd    = nullptr;
-        lctx.inp_pos     = nullptr;
-        lctx.inp_out_ids = nullptr;
-        lctx.inp_KQ_mask = nullptr;
-        lctx.inp_K_shift = nullptr;
-        lctx.inp_mean    = nullptr;
-        lctx.inp_cls     = nullptr;
-        lctx.inp_s_copy  = nullptr;
-        lctx.inp_s_mask  = nullptr;
-        lctx.inp_s_seq   = nullptr;
+        lctx.inp_tokens      = nullptr;
+        lctx.inp_embd        = nullptr;
+        lctx.inp_pos         = nullptr;
+        lctx.inp_out_ids     = nullptr;
+        lctx.inp_KQ_mask     = nullptr;
+        lctx.inp_KQ_mask_swa = nullptr;
+        lctx.inp_K_shift     = nullptr;
+        lctx.inp_mean        = nullptr;
+        lctx.inp_cls         = nullptr;
+        lctx.inp_s_copy      = nullptr;
+        lctx.inp_s_mask      = nullptr;
+        lctx.inp_s_seq       = nullptr;
     }
 
     void free() {
@@ -7804,7 +7813,6 @@ struct llm_build_context {
         cb(lctx.inp_K_shift, "K_shift", -1);
         ggml_set_input(lctx.inp_K_shift);
 
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * rope_factors = build_rope_factors(il);
             struct ggml_tensor * tmp =
@@ -7939,16 +7947,27 @@ struct llm_build_context {
     }
 
     struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        if (causal) {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        } else {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        }
+        lctx.inp_KQ_mask = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         cb(lctx.inp_KQ_mask, "KQ_mask", -1);
         ggml_set_input(lctx.inp_KQ_mask);
+
         return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
     }
 
+    struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
+        GGML_ASSERT(hparams.n_swa > 0);
+
+        lctx.inp_KQ_mask_swa = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(lctx.inp_KQ_mask_swa);
+
+        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
+    }
+
     struct ggml_tensor * build_inp_mean() {
         lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
         cb(lctx.inp_mean, "inp_mean", -1);
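Note that build_inp_KQ_mask_swa mirrors build_inp_KQ_mask exactly: same element type, same shape, same optional F16 cast for flash attention; only the tensor, its graph name and the callback label differ. In the causal case the mask has one column per KV cell and a row count of n_tokens rounded up to a multiple of GGML_KQ_MASK_PAD. A small sketch of that round-up, using a generic helper standing in for the GGML_PAD macro (the padding constant below is an illustrative assumption, not necessarily the value GGML_KQ_MASK_PAD has in llama.cpp):

#include <cstdio>

// generic "round x up to the next multiple of n" helper, standing in for GGML_PAD
static int pad_up(int x, int n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const int n_kv     = 512; // example size of the KV cache view
    const int n_tokens = 100; // example batch size
    const int mask_pad = 32;  // assumed stand-in for GGML_KQ_MASK_PAD

    // causal case: one column per KV cell, rows padded up for the batch
    printf("mask shape: %d x %d floats\n", n_kv, pad_up(n_tokens, mask_pad));
    return 0;
}
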
@@ -11029,9 +11048,14 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        // gemma 2 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask(true);
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
 
         for (int il = 0; il < n_layer; ++il) {
+            // (il % 2) layers use SWA
+            struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
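As the added comment says, Gemma 2 interleaves attention types: even-indexed layers (il % 2 == 0) attend through the sliding-window mask, odd-indexed layers through the full causal mask. A standalone sketch of that per-layer selection (the layer count is just an example):

#include <cstdio>

int main() {
    const int n_layer = 8; // example; real Gemma 2 models have more layers

    for (int il = 0; il < n_layer; ++il) {
        // mirrors: KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
        const bool use_swa = (il % 2 == 0);
        printf("layer %2d uses %s\n", il,
               use_swa ? "KQ_mask_swa (sliding window)" : "KQ_mask (global causal)");
    }
    return 0;
}
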
@@ -11067,7 +11091,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask  , n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             cur = llm_build_norm(ctx0, cur, hparams,
@@ -12670,7 +12694,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-        float * data = (float *) lctx.inp_KQ_mask->data;
+        float * data     = (float *) lctx.inp_KQ_mask->data;
+        float * data_swa = nullptr;
+
+        if (lctx.inp_KQ_mask_swa) {
+            data_swa = (float *) lctx.inp_KQ_mask_swa->data;
+        }
 
         // For causal attention, use only the previous KV cells
         // of the correct sequence for each token of the batch.
@@ -12692,6 +12721,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         }
                     }
                     data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                    // may need to cut off old tokens for sliding window
+                    if (data_swa) {
+                        if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+                            f = -INFINITY;
+                        }
+                        data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                    }
                 }
             }
 
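To make the added masking rule concrete: a KV cell is already masked by the causal check when its position is ahead of the query token, and the SWA mask additionally drops cells whose position lags the query by n_swa or more. Below is a self-contained sketch that fills both masks for one head with the same j*n_kv + i indexing as the loop above; sequence handling, ALiBi and the padded rows are left out, and n_kv, n_tokens, n_swa and the positions are example values:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int     n_kv     = 8; // KV cells in view (example)
    const int     n_tokens = 4; // batch tokens (example)
    const int32_t n_swa    = 4; // sliding window size (Gemma 2 defaults to 4096)

    // one row per batch token, one column per KV cell, as in inp_KQ_mask / inp_KQ_mask_swa
    std::vector<float> data    (n_kv * n_tokens, 0.0f);
    std::vector<float> data_swa(n_kv * n_tokens, 0.0f);

    for (int j = 0; j < n_tokens; ++j) {
        const int32_t pos = 4 + j; // position of batch token j (cache already holds positions 0..3)
        for (int i = 0; i < n_kv; ++i) {
            const int32_t cell_pos = i; // position stored in KV cell i

            float f = (cell_pos > pos) ? -INFINITY : 0.0f; // causal part
            data[j*n_kv + i] = f;

            // sliding window: additionally cut off cells older than n_swa positions
            if (pos - cell_pos >= n_swa) {
                f = -INFINITY;
            }
            data_swa[j*n_kv + i] = f;
        }
    }

    // 'S' = visible under both masks, 'g' = visible only to the global mask, '.' = masked
    for (int j = 0; j < n_tokens; ++j) {
        printf("token %d: ", j);
        for (int i = 0; i < n_kv; ++i) {
            const bool vis     = data    [j*n_kv + i] == 0.0f;
            const bool vis_swa = data_swa[j*n_kv + i] == 0.0f;
            printf("%c", vis ? (vis_swa ? 'S' : 'g') : '.');
        }
        printf("\n");
    }
    return 0;
}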