Skip to content

Commit 49122a8

Browse files
ngxsonarlo-phoenixggerganov
authored
gemma2: add sliding window mask (#8227)
* gemma2: add sliding window mask * fix data_swa uninitialized * better naming * add co-author Co-authored-by: Arlo Phoenix <[email protected]> * replace list with single tensor * update * llama : minor styling * convert : add sanity check for query_pre_attn_scalar * fix small typo in README --------- Co-authored-by: Arlo Phoenix <[email protected]> Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 0ddeff1 commit 49122a8

File tree

5 files changed

+79
-32
lines changed

5 files changed

+79
-32
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:
218218
**Tools:**
219219

220220
- [akx/ggify](https://github.com/akx/ggify) – download PyTorch models from HuggingFace Hub and convert them to GGML
221-
[crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
221+
- [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
222222

223223
---
224224

convert-hf-to-gguf.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2369,6 +2369,12 @@ def set_gguf_parameters(self):
23692369
self.gguf_writer.add_final_logit_softcapping(
23702370
self.hparams["final_logit_softcapping"]
23712371
)
2372+
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
2373+
2374+
# sanity check
2375+
attn_scalar = self.hparams["query_pre_attn_scalar"]
2376+
if attn_scalar != hparams["hidden_size"] / hparams["num_attention_heads"]:
2377+
raise ValueError("query_pre_attn_scalar must be equal to n_embd / n_head")
23722378

23732379
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
23742380
del bid # unusem

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class Attention:
6666
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
6767
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
6868
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
69+
SLIDING_WINDOW = "{arch}.attention.sliding_window"
6970

7071
class Rope:
7172
DIMENSION_COUNT = "{arch}.rope.dimension_count"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,9 @@ def add_kv_lora_rank(self, length: int) -> None:
552552
def add_relative_attn_buckets_count(self, value: int) -> None:
553553
self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
554554

555+
def add_sliding_window(self, value: int) -> None:
556+
self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
557+
555558
def add_pooling_type(self, value: PoolingType) -> None:
556559
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
557560

src/llama.cpp

Lines changed: 68 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ enum llm_kv {
317317
LLM_KV_ATTENTION_Q_LORA_RANK,
318318
LLM_KV_ATTENTION_KV_LORA_RANK,
319319
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
320+
LLM_KV_ATTENTION_SLIDING_WINDOW,
320321

321322
LLM_KV_ROPE_DIMENSION_COUNT,
322323
LLM_KV_ROPE_FREQ_BASE,
@@ -409,6 +410,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
409410
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
410411
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
411412
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
413+
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
412414

413415
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
414416
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
@@ -2085,6 +2087,7 @@ struct llama_hparams {
20852087
uint32_t n_head_kv;
20862088
uint32_t n_layer;
20872089
uint32_t n_rot;
2090+
uint32_t n_swa = 0; // sliding window attention (SWA)
20882091
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
20892092
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
20902093
uint32_t n_ff;
@@ -2139,6 +2142,7 @@ struct llama_hparams {
21392142
if (this->n_head_kv != other.n_head_kv) return true;
21402143
if (this->n_layer != other.n_layer) return true;
21412144
if (this->n_rot != other.n_rot) return true;
2145+
if (this->n_swa != other.n_swa) return true;
21422146
if (this->n_embd_head_k != other.n_embd_head_k) return true;
21432147
if (this->n_embd_head_v != other.n_embd_head_v) return true;
21442148
if (this->n_ff != other.n_ff) return true;
@@ -2649,17 +2653,18 @@ struct llama_context {
26492653
void * abort_callback_data = nullptr;
26502654

26512655
// input tensors
2652-
struct ggml_tensor * inp_tokens; // I32 [n_batch]
2653-
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
2654-
struct ggml_tensor * inp_pos; // I32 [n_batch]
2655-
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
2656-
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
2657-
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
2658-
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
2659-
struct ggml_tensor * inp_cls; // I32 [n_batch]
2660-
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
2661-
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
2662-
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
2656+
struct ggml_tensor * inp_tokens; // I32 [n_batch]
2657+
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
2658+
struct ggml_tensor * inp_pos; // I32 [n_batch]
2659+
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
2660+
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
2661+
struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
2662+
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
2663+
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
2664+
struct ggml_tensor * inp_cls; // I32 [n_batch]
2665+
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
2666+
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
2667+
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
26632668

26642669
// control vectors
26652670
struct llama_control_vector cvec;
@@ -4709,6 +4714,8 @@ static void llm_load_hparams(
47094714
} break;
47104715
case LLM_ARCH_GEMMA2:
47114716
{
4717+
hparams.n_swa = 4096; // default value of gemma 2
4718+
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
47124719
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
47134720
ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
47144721
ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
@@ -5419,6 +5426,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
54195426
LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
54205427
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
54215428
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
5429+
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
54225430
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
54235431
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
54245432
LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
@@ -7775,17 +7783,18 @@ struct llm_build_context {
77757783

77767784
ctx0 = ggml_init(params);
77777785

7778-
lctx.inp_tokens = nullptr;
7779-
lctx.inp_embd = nullptr;
7780-
lctx.inp_pos = nullptr;
7781-
lctx.inp_out_ids = nullptr;
7782-
lctx.inp_KQ_mask = nullptr;
7783-
lctx.inp_K_shift = nullptr;
7784-
lctx.inp_mean = nullptr;
7785-
lctx.inp_cls = nullptr;
7786-
lctx.inp_s_copy = nullptr;
7787-
lctx.inp_s_mask = nullptr;
7788-
lctx.inp_s_seq = nullptr;
7786+
lctx.inp_tokens = nullptr;
7787+
lctx.inp_embd = nullptr;
7788+
lctx.inp_pos = nullptr;
7789+
lctx.inp_out_ids = nullptr;
7790+
lctx.inp_KQ_mask = nullptr;
7791+
lctx.inp_KQ_mask_swa = nullptr;
7792+
lctx.inp_K_shift = nullptr;
7793+
lctx.inp_mean = nullptr;
7794+
lctx.inp_cls = nullptr;
7795+
lctx.inp_s_copy = nullptr;
7796+
lctx.inp_s_mask = nullptr;
7797+
lctx.inp_s_seq = nullptr;
77897798
}
77907799

77917800
void free() {
@@ -7804,7 +7813,6 @@ struct llm_build_context {
78047813
cb(lctx.inp_K_shift, "K_shift", -1);
78057814
ggml_set_input(lctx.inp_K_shift);
78067815

7807-
78087816
for (int il = 0; il < n_layer; ++il) {
78097817
struct ggml_tensor * rope_factors = build_rope_factors(il);
78107818
struct ggml_tensor * tmp =
@@ -7939,16 +7947,27 @@ struct llm_build_context {
79397947
}
79407948

79417949
struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
7942-
if (causal) {
7943-
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
7944-
} else {
7945-
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
7946-
}
7950+
lctx.inp_KQ_mask = causal
7951+
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
7952+
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
79477953
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
79487954
ggml_set_input(lctx.inp_KQ_mask);
7955+
79497956
return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
79507957
}
79517958

7959+
struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
7960+
GGML_ASSERT(hparams.n_swa > 0);
7961+
7962+
lctx.inp_KQ_mask_swa = causal
7963+
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
7964+
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
7965+
cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
7966+
ggml_set_input(lctx.inp_KQ_mask_swa);
7967+
7968+
return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
7969+
}
7970+
79527971
struct ggml_tensor * build_inp_mean() {
79537972
lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
79547973
cb(lctx.inp_mean, "inp_mean", -1);
@@ -11029,9 +11048,14 @@ struct llm_build_context {
1102911048
struct ggml_tensor * inp_pos = build_inp_pos();
1103011049

1103111050
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
11032-
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
11051+
// gemma 2 requires different mask for layers using sliding window (SWA)
11052+
struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true);
11053+
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
1103311054

1103411055
for (int il = 0; il < n_layer; ++il) {
11056+
// (il % 2) layers use SWA
11057+
struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
11058+
1103511059
// norm
1103611060
cur = llm_build_norm(ctx0, inpL, hparams,
1103711061
model.layers[il].attn_norm, NULL,
@@ -11067,7 +11091,7 @@ struct llm_build_context {
1106711091

1106811092
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
1106911093
model.layers[il].wo, NULL,
11070-
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
11094+
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
1107111095
}
1107211096

1107311097
cur = llm_build_norm(ctx0, cur, hparams,
@@ -12670,7 +12694,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
1267012694

1267112695
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
1267212696

12673-
float * data = (float *) lctx.inp_KQ_mask->data;
12697+
float * data = (float *) lctx.inp_KQ_mask->data;
12698+
float * data_swa = nullptr;
12699+
12700+
if (lctx.inp_KQ_mask_swa) {
12701+
data_swa = (float *) lctx.inp_KQ_mask_swa->data;
12702+
}
1267412703

1267512704
// For causal attention, use only the previous KV cells
1267612705
// of the correct sequence for each token of the batch.
@@ -12692,6 +12721,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
1269212721
}
1269312722
}
1269412723
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
12724+
12725+
// may need to cut off old tokens for sliding window
12726+
if (data_swa) {
12727+
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
12728+
f = -INFINITY;
12729+
}
12730+
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
12731+
}
1269512732
}
1269612733
}
1269712734

0 commit comments

Comments
 (0)