
Commit bed4c73

another approach

1 parent 5dec47d commit bed4c73

2 files changed: +65 additions, -99 deletions

2 files changed

+65
-99
lines changed

src/llama-context.cpp

Lines changed: 2 additions & 2 deletions
@@ -1317,8 +1317,8 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs = n_outputs_new;
     }
 
-    // non-causal masks do not use the KV cache
-    if (hparams.causal_attn) {
+    // find KV slot
+    {
         kv_self_update();
 
         // if we have enough unused cells before the current head ->
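
Taken on its own, this hunk means the KV-slot search in llama_context::decode is no longer gated on hparams.causal_attn: non-causal (encoder-style) batches now also claim cells in the KV cache and flow through the same masking code as causal ones. A minimal sketch of that control-flow change, with invented names (kv_cache, find_kv_slot) standing in for the real decode path:

    // Hypothetical stand-ins for the decode path; not llama.cpp's actual API.
    #include <cstdio>

    struct kv_cache {
        int head = 0; // next free cell (the real slot search is more involved)
    };

    // Stand-in for the "find KV slot" block: after this commit it runs for
    // every batch, causal or not.
    static int find_kv_slot(kv_cache & kv, int n_tokens) {
        const int slot = kv.head;
        kv.head += n_tokens; // claim n_tokens contiguous cells (simplified)
        return slot;
    }

    int main() {
        kv_cache kv;
        // both kinds of batch now take the same path:
        printf("causal batch starts at cell %d\n",     find_kv_slot(kv, 3));
        printf("non-causal batch starts at cell %d\n", find_kv_slot(kv, 3));
    }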

src/llama-graph.cpp

Lines changed: 63 additions & 97 deletions
@@ -402,120 +402,86 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask || self_kq_mask_swa) {
-        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn) {
-            const int64_t n_kv         = kv_self->n;
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-
-            float * data     = nullptr;
-            float * data_swa = nullptr;
-
-            if (self_kq_mask) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-                data = (float *) self_kq_mask->data;
-            }
+        const int64_t n_kv         = kv_self->n;
+        const int64_t n_tokens     = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs       = ubatch->n_seqs;
 
-            if (self_kq_mask_swa) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-                data_swa = (float *) self_kq_mask_swa->data;
-            }
+        float * data     = nullptr;
+        float * data_swa = nullptr;
 
-            // For causal attention, use only the previous KV cells
-            // of the correct sequence for each token of the ubatch.
-            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            for (int h = 0; h < 1; ++h) {
-                for (int s = 0; s < n_seqs; ++s) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        if (self_kq_mask) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+            data = (float *) self_kq_mask->data;
+        }
 
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+        if (self_kq_mask_swa) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+            data_swa = (float *) self_kq_mask_swa->data;
+        }
 
-                        for (int i = 0; i < n_kv; ++i) {
-                            float f;
-                            if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
-                                f = -INFINITY;
+        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
+        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
+        //   Causal mask:
+        //      xxx-------
+        //      xxxx------
+        //      xxxxx-----
+        //   Non-causal mask:
+        //      xxxxx-----
+        //      xxxxx-----
+        //      xxxxx-----
+        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+        for (int h = 0; h < 1; ++h) {
+            for (int s = 0; s < n_seqs; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+                for (int j = 0; j < n_seq_tokens; ++j) {
+                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        // mask the token if:
+                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
+                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
+                        ) {
+                            f = -INFINITY;
+                        } else {
+                            if (hparams.use_alibi) {
+                                f = -std::abs(kv_self->cells[i].pos - pos);
                             } else {
-                                if (hparams.use_alibi) {
-                                    f = -std::abs(kv_self->cells[i].pos - pos);
-                                } else {
-                                    f = 0.0f;
-                                }
-                            }
-
-                            if (data) {
-                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                                f = 0.0f;
                             }
+                        }
 
-                            // may need to cut off old tokens for sliding window
-                            if (data_swa) {
-                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                            }
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                         }
-                    }
-                }
 
-                if (data) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        // may need to cut off old tokens for sliding window
+                        if (data_swa) {
+                            if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
+                                f = -INFINITY;
+                            }
+                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                         }
                     }
                 }
+            }
 
-                if (data_swa) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
+            // mask padded tokens
+            if (data) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
                 }
             }
-        } else {
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-            // when using kv cache, the mask needs to match the kv cache size
-            const int64_t n_stride = n_tokens;
 
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-
-            float * data = (float *) self_kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
-
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
+            // mask padded tokens
+            if (data_swa) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
                 }
             }
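
To make the new unified masking concrete, here is a self-contained sketch (not part of the commit; the cell struct, print_mask, and the hard-coded positions are illustrative) that reproduces the example from the comment in the diff — a cache of 10 cells, 2 already populated, 3 batch tokens whose KV slots were just assigned — and prints the causal and non-causal patterns:

    // Illustrative only: single sequence, no ALiBi, no sliding window.
    #include <cstdio>
    #include <vector>

    struct cell { int pos; bool used; };

    static void print_mask(bool causal_attn) {
        const int n_kv = 10;
        std::vector<cell> cells(n_kv, cell{-1, false});
        cells[0] = {0, true};               // 2 tokens already in the cache ...
        cells[1] = {1, true};
        const int batch_pos[3] = {2, 3, 4}; // ... and 3 new tokens in the batch,
        for (int b = 0; b < 3; ++b) {       // already placed in their KV slots
            cells[2 + b] = {batch_pos[b], true};
        }

        printf("%s mask:\n", causal_attn ? "Causal" : "Non-causal");
        for (int j = 0; j < 3; ++j) {        // one row per batch token
            for (int i = 0; i < n_kv; ++i) { // one column per KV cell
                // same rule as the merged condition in the diff: mask unused
                // cells, and for causal attention also mask future positions
                const bool masked = !cells[i].used
                    || (causal_attn && cells[i].pos > batch_pos[j]);
                putchar(masked ? '-' : 'x');
            }
            putchar('\n');
        }
    }

    int main() {
        print_mask(true);  // xxx------- / xxxx------ / xxxxx-----
        print_mask(false); // xxxxx----- / xxxxx----- / xxxxx-----
    }

The payoff of the merged condition is that one loop serves both modes: the causal/non-causal distinction is a single per-cell test rather than two separate code paths with different mask layouts.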
