diff --git a/src/llama-context.cpp b/src/llama-context.cpp index aa363df6356ea..7e6ae2fd32b5e 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1317,8 +1317,8 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs = n_outputs_new; } - // non-causal masks do not use the KV cache - if (hparams.causal_attn) { + // find KV slot + { kv_self_update(); // if we have enough unused cells before the current head -> diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0bd40174438cc..cec203df49268 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -402,120 +402,86 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { if (self_kq_mask || self_kq_mask_swa) { - // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. - if (cparams.causal_attn) { - const int64_t n_kv = kv_self->n; - const int64_t n_tokens = ubatch->n_tokens; - const int64_t n_seq_tokens = ubatch->n_seq_tokens; - const int64_t n_seqs = ubatch->n_seqs; - - float * data = nullptr; - float * data_swa = nullptr; - - if (self_kq_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); - data = (float *) self_kq_mask->data; - } + const int64_t n_kv = kv_self->n; + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; - if (self_kq_mask_swa) { - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); - data_swa = (float *) self_kq_mask_swa->data; - } + float * data = nullptr; + float * data_swa = nullptr; - // For causal attention, use only the previous KV cells - // of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. - for (int h = 0; h < 1; ++h) { - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[s][0]; + if (self_kq_mask) { + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); + data = (float *) self_kq_mask->data; + } - for (int j = 0; j < n_seq_tokens; ++j) { - const llama_pos pos = ubatch->pos[s*n_seq_tokens + j]; + if (self_kq_mask_swa) { + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); + data_swa = (float *) self_kq_mask_swa->data; + } - for (int i = 0; i < n_kv; ++i) { - float f; - if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) { - f = -INFINITY; + // Use only the previous KV cells of the correct sequence for each token of the ubatch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. + // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: + // Causal mask: + // xxx------- + // xxxx------ + // xxxxx----- + // Non-causal mask: + // xxxxx----- + // xxxxx----- + // xxxxx----- + // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 + for (int h = 0; h < 1; ++h) { + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const llama_pos pos = ubatch->pos[s*n_seq_tokens + j]; + for (int i = 0; i < n_kv; ++i) { + float f; + // mask the token if: + if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence + || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens + ) { + f = -INFINITY; + } else { + if (hparams.use_alibi) { + f = -std::abs(kv_self->cells[i].pos - pos); } else { - if (hparams.use_alibi) { - f = -std::abs(kv_self->cells[i].pos - pos); - } else { - f = 0.0f; - } - } - - if (data) { - data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + f = 0.0f; } + } - // may need to cut off old tokens for sliding window - if (data_swa) { - if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) { - f = -INFINITY; - } - data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; - } + if (data) { + data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; } - } - } - if (data) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + // may need to cut off old tokens for sliding window + if (data_swa) { + if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } + data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; } } } + } - if (data_swa) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; - } + // mask padded tokens + if (data) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; } } } - } else { - const int64_t n_tokens = ubatch->n_tokens; - const int64_t n_seq_tokens = ubatch->n_seq_tokens; - const int64_t n_seqs = ubatch->n_seqs; - // when using kv cache, the mask needs to match the kv cache size - const int64_t n_stride = n_tokens; - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); - - float * data = (float *) self_kq_mask->data; - - for (int h = 0; h < 1; ++h) { - for (int s1 = 0; s1 < n_seqs; ++s1) { - const llama_seq_id seq_id = ubatch->seq_id[s1][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const int32_t tj = s1*n_seq_tokens + j; - - for (int s0 = 0; s0 < n_seqs; ++s0) { - for (int i = 0; i < n_seq_tokens; ++i) { - const int32_t ti = s0*n_seq_tokens + i; - float f = -INFINITY; - - for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { - if (ubatch->seq_id[s0][s] == seq_id) { - if (hparams.use_alibi) { - f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]); - } else { - f = 0.0f; - } - break; - } - } - - data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; - } - } - - for (int i = n_tokens; i < n_stride; ++i) { - data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; - } + // mask padded tokens + if (data_swa) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; } } }