
Commit ab2c3de

fix data_swa uninitialized

1 parent 7df7530 commit ab2c3de

src/llama.cpp

Lines changed: 4 additions & 3 deletions
@@ -12687,12 +12687,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
         float * data = (float *) lctx.inp_KQ_mask->data;
         float * data_swa = nullptr;
+        const llama_pos n_keep_swa = hparams.n_ctx_swa - batch.n_tokens;
 
         if (lctx.model.arch == LLM_ARCH_GEMMA2) {
             GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
             GGML_ASSERT(hparams.n_ctx_swa > 0);
             data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
             data = (float *) lctx.inp_KQ_mask_l[1]->data;
+            // because layer masks are alternate for gemma 2, we only need to take first 2 layers
         }
 
         // For causal attention, use only the previous KV cells
@@ -12717,9 +12719,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                     data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
 
                     // may need to cut off old tokens for sliding window
-                    if (data_swa && f != -INFINITY) {
-                        const llama_pos n_keep = hparams.n_ctx_swa - batch.n_tokens;
-                        if (pos - lctx.kv_self.cells[i].pos > n_keep) {
+                    if (data_swa) {
+                        if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
                             f = -INFINITY;
                         }
                         data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
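
Why this fixes the uninitialized read: the old guard `if (data_swa && f != -INFINITY)` skipped the store into data_swa for any cell the causal mask had already set to -INFINITY, so those entries of the sliding-window mask kept whatever the buffer previously held. The new guard drops the `f != -INFINITY` check (and hoists n_keep_swa out of the loop), so every data_swa slot is written exactly once per pass. Below is a minimal standalone sketch of that fill logic; the toy sizes, positions, and the use of NAN as a stand-in for uninitialized memory are illustrative assumptions, not the actual llama.cpp data layout.

// Minimal sketch (assumptions: toy sizes, NAN marks "never written") of the
// KQ mask fill loop, showing why the old guard left data_swa entries unset.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_kv       = 8; // number of KV cells (toy value)
    const int pos        = 5; // position of the token being masked (toy value)
    const int n_keep_swa = 3; // sliding-window span, i.e. n_ctx_swa - n_tokens

    // NAN stands in for "uninitialized"; the real buffers are raw tensor data.
    std::vector<float> data(n_kv, NAN), data_swa(n_kv, NAN);

    for (int i = 0; i < n_kv; ++i) {
        // causal mask: only cells at or before `pos` are visible
        float f = (i <= pos) ? 0.0f : -INFINITY;
        data[i] = f;

        // Old code also required `f != -INFINITY` here, so for cells 6 and 7
        // (already causally masked) data_swa[i] was never written.
        // New code writes data_swa[i] unconditionally whenever data_swa exists.
        if (pos - i > n_keep_swa) {
            f = -INFINITY; // cell is outside the sliding window
        }
        data_swa[i] = f;
    }

    for (int i = 0; i < n_kv; ++i) {
        std::printf("cell %d: data = %6.1f  data_swa = %6.1f\n", i, data[i], data_swa[i]);
    }
    return 0;
}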
