@@ -746,21 +746,25 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
 
     const int64_t n_tokens = k_cur->ne[2];
 
+    if (kv_idxs) {
+        return ggml_set_rows(ctx, k, ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens), kv_idxs);
+    }
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*hparams.n_embd_k_gqa(il),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -772,10 +776,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     ggml_tensor * v_view = nullptr;
 
     if (!v_trans) {
+        if (kv_idxs) {
+            return ggml_set_rows(ctx, v, ggml_reshape_2d(ctx, v_cur, v->ne[0], n_tokens), kv_idxs);
+        }
+
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*hparams.n_embd_v_gqa(il),
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
     } else {
+        if (kv_idxs) {
+            GGML_ABORT("TODO: implement kv_idxs for transposed V cache -- for now use flash attention");
+        }
+
         // note: the V cache is transposed when not using flash attention
         v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                 (v->ne[1])*ggml_element_size(v),
@@ -787,6 +799,17 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
+void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int32_t * data = (int32_t *) dst->data;
+
+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        data[i] = head_cur + i;
+    }
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
 
@@ -1789,18 +1812,22 @@ ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, kv_idxs, il, head);
 }
 
-ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, kv_idxs, il, head);
 }
 
 void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
+void llama_kv_cache_unified_state::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_kv_idxs(dst, ubatch, head);
+}
+
 void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
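
Background note (not part of the diff): set_input_kv_idxs() fills a 1-D integer tensor with the destination slots head_cur + i for each token of the ubatch, and cpy_k()/cpy_v() pass that tensor to ggml_set_rows() so each token's K/V row is scattered into the cache at its own index, instead of being copied into a single contiguous view as before. Below is a minimal sketch of how the graph-building side might create and wire such an index tensor; the names build_kv_idxs_sketch, ctx0 and n_tokens are hypothetical, and the actual llama.cpp graph code is not shown in these hunks.

// Sketch only -- illustrates the intended wiring, not code from this commit.
#include "ggml.h"

// Create the per-ubatch index tensor that set_input_kv_idxs() later fills.
static ggml_tensor * build_kv_idxs_sketch(ggml_context * ctx0, int64_t n_tokens) {
    // one destination cache row index per token (I32, matching the int32_t fill above)
    ggml_tensor * kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
    ggml_set_input(kv_idxs); // graph input: its data is written per ubatch at eval time
    return kv_idxs;
}

// At eval time, set_input_kv_idxs(dst, ubatch, head) writes head + 0, head + 1, ...
// into this tensor; cpy_k()/cpy_v() then call
//     ggml_set_rows(ctx, k, ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens), kv_idxs)
// which overwrites cache row kv_idxs[i] with row i of the reshaped current K/V.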