@@ -7794,9 +7794,7 @@ static void llm_build_kv_store(
     cb(k_cache_view, "k_cache_view", il);
 
     // note: storing RoPE-ed version of K in the KV cache
-    ggml_tensor * tmp = ggml_cpy(ctx, k_cur, k_cache_view);
-    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_K;
-    ggml_build_forward_expand(graph, tmp);
+    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
 
     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
 
@@ -7814,9 +7812,7 @@ static void llm_build_kv_store(
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
-    tmp=ggml_cpy(ctx, v_cur, v_cache_view);
-    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_V;
-    ggml_build_forward_expand(graph, tmp);
+    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
 }
 
 static struct ggml_tensor * llm_build_norm(
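Note on the two hunks above: they drop the custom kv_cache_flag tagging of the copy nodes (and the ggml_tensor * tmp temporary) and restore the plain one-liner. The cached-graph update further down can still locate these copies because cb(k_cache_view, "k_cache_view", il) and cb(v_cache_view, "v_cache_view", il) name each destination view with its layer index appended. A minimal sketch of that naming convention, assuming the default llama.cpp build callback (which formats names as "<name>-<layer>"):

    #include <cstdio>

    // Sketch only: mirrors how the default build callback names tensors,
    // so cb(k_cache_view, "k_cache_view", 7) yields "k_cache_view-7".
    static void name_view(char * dst, std::size_t n, const char * base, int il) {
        std::snprintf(dst, n, "%s-%d", base, il);
    }

    int main() {
        char name[64];
        name_view(name, sizeof(name), "k_cache_view", 7);
        std::printf("%s\n", name); // prints: k_cache_view-7
        return 0;
    }

Since a GGML_OP_CPY node keeps its destination as src[1], the view's name identifies both the cache (K or V) and the layer, without adding any extra field to ggml_tensor.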
@@ -14607,43 +14603,41 @@ static int llama_decode_internal(
         }
         lctx.cached_graph.gf = gf;
 
-        if(ggml_use_cached_graph(lctx.sched)) {
-
-            // Temporarily store KV cache parameters that will need updated in cached graph.
+        // Update K and V cache parameters in cached graph.
+        if(gf != nullptr && gf->nodes != nullptr && ggml_use_cached_graph(lctx.sched)) {
+
             const struct llama_hparams & hparams = model.hparams;
-            const int64_t n_layer = hparams.n_layer;
             const int64_t kv_head = kv_self.head;
-            std::vector<void *> k_cache_ptrs;
-            std::vector<void *> v_cache_ptrs;
-            for (int il = 0; il < n_layer; ++il) {
-                const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-                const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-                ggml_tensor * tmp_tensor = kv_self.k_l[il];
-                size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
-                k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
-                tmp_tensor = kv_self.v_l[il];
-                if (cparams.flash_attn) {
-                    tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-                } else {
-                    tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
-                }
-                v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
-            }
-
-            // Update KV cache parameters in cached graph.
-            int k_count = 0;
-            int v_count = 0;
-            if(gf != nullptr && gf->nodes != nullptr){
-                for (int i = 0; i < gf->n_nodes; i++) {
-                    ggml_tensor * node = gf->nodes[i];
-                    if (node->op == GGML_OP_CPY) {
-                        if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_K) {
-                            node->src[1]->data = k_cache_ptrs[k_count++];
-                        }
-                        if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_V) {
-                            node->src[1]->data = v_cache_ptrs[v_count++];
+            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+            for (int i = 0; i < gf->n_nodes; i++) {
+                ggml_tensor * node = gf->nodes[i];
+                if (node->op == GGML_OP_CPY) {
+
+                    // K cache
+                    const char* k_prefix = "k_cache_view-";
+                    if (strncmp(node->src[1]->name, k_prefix, strlen(k_prefix)) == 0) {
+                        int il = atoi(node->src[1]->name + strlen(k_prefix)); // Layer index from name
+                        ggml_tensor * tmp_tensor = kv_self.k_l[il];
+                        size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
+                        node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
+                    }
+
+                    // V cache
+                    const char* v_prefix = "v_cache_view-";
+                    if (strncmp(node->src[1]->name, v_prefix, strlen(v_prefix)) == 0) {
+                        int il = atoi(node->src[1]->name + strlen(v_prefix)); // Layer index from name
+                        ggml_tensor * tmp_tensor = kv_self.v_l[il];
+                        size_t tmp_offset;
+                        if (cparams.flash_attn) {
+                            tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                        } else {
+                            tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
                         }
+                        node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
                     }
+
                 }
             }
         }
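For reference, the name-prefix dispatch introduced in this hunk can be exercised in isolation. A self-contained sketch, using a hypothetical fake_tensor stand-in rather than the real ggml API:

    #include <cstdio>
    #include <cstdlib>
    #include <cstring>

    // Hypothetical stand-in for ggml_tensor; only the fields used here.
    struct fake_tensor {
        char   name[64];
        void * data;
    };

    // Returns the layer index if t's name starts with prefix, else -1,
    // mirroring the strncmp/atoi logic in the hunk above.
    static int match_layer(const fake_tensor * t, const char * prefix) {
        if (std::strncmp(t->name, prefix, std::strlen(prefix)) != 0) {
            return -1;
        }
        return std::atoi(t->name + std::strlen(prefix));
    }

    int main() {
        fake_tensor t = { "v_cache_view-12", nullptr };
        std::printf("%d\n", match_layer(&t, "k_cache_view-")); // -1 (not a K view)
        std::printf("%d\n", match_layer(&t, "v_cache_view-")); // 12
        return 0;
    }

The two offset formulas carried over from the old code reflect the two V-cache layouts: with flash attention, V is stored row-per-token like K, so the write position advances by kv_head rows of size ggml_row_size(type, n_embd_v_gqa); without it, the V cache is transposed, so moving one token position advances a single element, hence kv_head * ggml_element_size(...).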