@@ -2648,6 +2648,17 @@ struct llama_model {
 }
 };

+// Object used to allow caching of GGML graph between tokens where possible.
+struct ggml_cached_graph {
+    bool is_active = false;
+    ggml_cgraph * gf;
+    size_t n;
+    ggml_backend_t backend_res;
+    ggml_backend_t backend_embd;
+    struct ggml_tensor * res;
+    struct ggml_tensor * embd;
+};
+
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
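The new ggml_cached_graph object records everything llama_decode_internal needs to replay a previously built graph: the graph itself, the kv_self.n it was built for, the backends that held the logits and embeddings tensors, and the res/embd output tensors. The diff also relies on two scheduler hooks, ggml_use_cached_graph() and ggml_set_cached_graph(), whose ggml-side implementation is not shown here; the sketch below is only a guess at the minimal flag handshake they would implement, with the struct layout and field name invented for illustration.

// Sketch only: the ggml-side half of this handshake is not part of this diff, so the
// struct and field below are assumptions. The intent is a single flag on the scheduler
// that llama.cpp sets after deciding whether the previous token's graph may be replayed.
struct ggml_backend_sched_sketch {
    bool is_using_cached_graph = false;
};

// Queried by llama_decode_internal: may the previously built graph be reused?
static bool ggml_use_cached_graph(ggml_backend_sched_sketch * sched) {
    return sched->is_using_cached_graph;
}

// Set by llama_decode_internal after a rebuild, based on the previous token's decision.
static void ggml_set_cached_graph(ggml_backend_sched_sketch * sched, bool set_value) {
    sched->is_using_cached_graph = set_value;
}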
@@ -2748,6 +2759,10 @@ struct llama_context {
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+
+    // cached GGML graph (used for CUDA graph caching between tokens)
+    struct ggml_cached_graph cached_graph;
+
 };

 struct llama_lora_weight {
@@ -7902,7 +7917,9 @@ static void llm_build_kv_store(
     cb(k_cache_view, "k_cache_view", il);

     // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
+    ggml_tensor * tmp = ggml_cpy(ctx, k_cur, k_cache_view);
+    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_K;
+    ggml_build_forward_expand(graph, tmp);

     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);

@@ -7920,8 +7937,9 @@ static void llm_build_kv_store(
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
-
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
+    tmp = ggml_cpy(ctx, v_cur, v_cache_view);
+    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_V;
+    ggml_build_forward_expand(graph, tmp);
 }

 // do mat_mul, while optionally apply lora
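The two hunks above tag the ggml_cpy nodes that write into the K and V caches so that the cached-graph path in llama_decode_internal can find them later and repoint their destinations. This assumes a companion ggml change, not shown in this diff, that adds a kv_cache_flag field to ggml_tensor; GGML_KV_CACHE_FLAG_K and GGML_KV_CACHE_FLAG_V are taken from the diff, while the NONE default and the exact enum/field shape in the sketch below are assumptions for illustration only.

// Assumed ggml-side declarations (the companion ggml patch is not shown in this diff).
enum ggml_kv_cache_flag {
    GGML_KV_CACHE_FLAG_NONE = 0,
    GGML_KV_CACHE_FLAG_K    = 1,
    GGML_KV_CACHE_FLAG_V    = 2,
};

struct ggml_tensor_sketch {
    // ... existing ggml_tensor members (op, src, ne, data, ...) ...
    enum ggml_kv_cache_flag kv_cache_flag = GGML_KV_CACHE_FLAG_NONE;
};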
@@ -14729,12 +14747,44 @@ static int llama_decode_internal(
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);

-        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
-
+        ggml_cgraph * gf;
         // the output is always the last tensor in the graph
-        struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
-        struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+        struct ggml_tensor * res;
+        struct ggml_tensor * embd;
+
+        bool n_has_changed_since_last_token = false;
+        if (lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
+        lctx.cached_graph.n = kv_self.n;
+
+        // Re-build the graph only if graph caching is not possible
+        if (!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
+
+            gf = llama_build_graph(lctx, u_batch, false);
+
+            // Set whether GGML graph caching is in use within the GGML module, based on
+            // whether caching was activated here during the previous token
+            ggml_set_cached_graph(lctx.sched, lctx.cached_graph.is_active);
+
+            // Disable future graph caching in the presence of the env var,
+            // if there are multiple devices, if the batch size is greater than 1,
+            // or if the number of splits is not 2.
+            // TODO: enable graph caching for these cases
+            bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+                || (llama_get_device_count(model) > 1)
+                || (ggml_backend_sched_get_n_splits(lctx.sched) != 2);
+            for (int i = 0; i < gf->n_nodes; i++) {
+                if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+                    disable_cached_ggml_graph = true;
+                    break;
+                }
+            }
+
+            // Set whether graph caching should be used for future tokens
+            lctx.cached_graph.is_active = !disable_cached_ggml_graph;

+            // the output is always the last tensor in the graph
+            res  = gf->nodes[gf->n_nodes - 1];
+            embd = gf->nodes[gf->n_nodes - 2];
         if (lctx.n_outputs == 0) {
             // no output
             res = nullptr;
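The block above decides, once per rebuilt graph, whether future tokens may reuse it. Caching is refused when the GGML_DISABLE_GRAPH_CACHING environment variable is set, when more than one device is involved, when the scheduler produces anything other than two splits, or when any GGML_OP_ADD node has a second source with ne[1] > 1, which is how a ubatch carrying more than one token shows up in the graph. A minimal sketch of that heuristic, factored into a hypothetical helper for readability (the inline code above is the actual patch):

// Hypothetical helper equivalent to the inline eligibility check above; ggml.h provides
// ggml_cgraph, ggml_tensor and GGML_OP_ADD.
#include "ggml.h"
#include <cstdlib>

static bool graph_caching_allowed(const ggml_cgraph * gf, int n_devices, int n_splits) {
    if (std::getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr) return false;
    if (n_devices > 1 || n_splits != 2) return false;
    for (int i = 0; i < gf->n_nodes; i++) {
        const ggml_tensor * node = gf->nodes[i];
        // a bias add whose second operand spans more than one column means the ubatch
        // holds more than one token, which the cached-graph path does not handle yet
        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
            return false;
        }
    }
    return true;
}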
@@ -14750,10 +14800,62 @@ static int llama_decode_internal(
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+        lctx.cached_graph.res  = res;
+        lctx.cached_graph.embd = embd;
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);

         ggml_backend_sched_alloc_graph(lctx.sched, gf);

+        }
+        else {
+            gf   = lctx.cached_graph.gf;
+            res  = lctx.cached_graph.res;
+            embd = lctx.cached_graph.embd;
+        }
+        lctx.cached_graph.gf = gf;
+
+        if (ggml_use_cached_graph(lctx.sched)) {
+
+            // Temporarily store the KV cache parameters that will need to be updated in the cached graph.
+            const struct llama_hparams & hparams = model.hparams;
+            const int64_t n_layer = hparams.n_layer;
+            const int64_t kv_head = kv_self.head;
+            std::vector<void *> k_cache_ptrs;
+            std::vector<void *> v_cache_ptrs;
+            for (int il = 0; il < n_layer; ++il) {
+                const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+                const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+                ggml_tensor * tmp_tensor = kv_self.k_l[il];
+                size_t tmp_offset = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head;
+                k_cache_ptrs.push_back(static_cast<char *>(tmp_tensor->data) + tmp_offset);
+                tmp_tensor = kv_self.v_l[il];
+                if (cparams.flash_attn) {
+                    tmp_offset = kv_head*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                } else {
+                    tmp_offset = kv_head*ggml_element_size(kv_self.v_l[il]);
+                }
+                v_cache_ptrs.push_back(static_cast<char *>(tmp_tensor->data) + tmp_offset);
+            }
+
+            // Update the KV cache parameters in the cached graph.
+            int k_count = 0;
+            int v_count = 0;
+            if (gf != nullptr && gf->nodes != nullptr) {
+                for (int i = 0; i < gf->n_nodes; i++) {
+                    ggml_tensor * node = gf->nodes[i];
+                    if (node->op == GGML_OP_CPY) {
+                        if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_K) {
+                            node->src[1]->data = k_cache_ptrs[k_count++];
+                        }
+                        if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_V) {
+                            node->src[1]->data = v_cache_ptrs[v_count++];
+                        }
+                    }
+                }
+            }
+
+        }
+
         llama_set_inputs(lctx, u_batch);

         llama_graph_compute(lctx, gf, n_threads);
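When the graph is reused, the views created by llm_build_kv_store still point at the kv_self.head of the token for which the graph was built, so the loop above recomputes each layer's destination pointer from the current head and patches the data pointer of every flagged GGML_OP_CPY node. K advances by whole rows of the cache, while V advances by single elements unless flash attention keeps it in the same row-major layout as K. A small restatement of that arithmetic, using the same ggml helpers (the function names are illustrative, not part of the patch):

// Illustrative restatement of the per-layer destination offsets computed above.
#include "ggml.h"

static size_t k_cache_offset(const ggml_tensor * k_l, int64_t n_embd_k_gqa, int64_t kv_head) {
    // K is stored row-major: skip kv_head full rows of n_embd_k_gqa (possibly quantized) elements
    return ggml_row_size(k_l->type, n_embd_k_gqa)*kv_head;
}

static size_t v_cache_offset(const ggml_tensor * v_l, int64_t n_embd_v_gqa, int64_t kv_head, bool flash_attn) {
    // with flash attention V shares the K layout; otherwise V is stored transposed, so a
    // new cell advances by one element within each of the n_embd_v_gqa rows
    return flash_attn ? ggml_row_size(v_l->type, n_embd_v_gqa)*kv_head
                      : ggml_element_size(v_l)*kv_head;
}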
@@ -14776,11 +14878,15 @@ static int llama_decode_internal(
         // extract logits
         if (res) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
-            GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(lctx.logits != nullptr);
-
             float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
             const int32_t n_outputs_new = lctx.n_outputs;
+            if (!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_res = backend_res;
+            else
+                backend_res = lctx.cached_graph.backend_res;
+
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(lctx.logits != nullptr);

             if (n_outputs_new) {
                 GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
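Because ggml_backend_sched_get_tensor_backend() is only meaningful for the graph the scheduler has just split and allocated, the backend handle for the logits tensor is recorded on the build pass and restored on cached passes, before the asserts run. The same pattern is applied to the embeddings in the next hunk; a compact sketch of it (the helper is hypothetical):

// Hypothetical helper capturing the record-or-restore pattern used here for backend_res
// and below for backend_embd.
#include "ggml-backend.h"

static ggml_backend_t record_or_restore_backend(bool using_cached_graph,
                                                ggml_backend_t   queried,
                                                ggml_backend_t & cached_slot) {
    if (!using_cached_graph) {
        cached_slot = queried; // build pass: remember the backend the scheduler picked
        return queried;
    }
    return cached_slot;        // cached pass: the fresh query may be stale, reuse the stored handle
}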
@@ -14792,6 +14898,12 @@ static int llama_decode_internal(
         // extract embeddings
         if (embd) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+
+
+            if (!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_embd = backend_embd;
+            else
+                backend_embd = lctx.cached_graph.backend_embd;
             GGML_ASSERT(backend_embd != nullptr);

             switch (cparams.pooling_type) {
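Taken together, these changes only pay off for the common interactive case: one device, two scheduler splits, and single-token ubatches, i.e. ordinary autoregressive generation after the prompt has been evaluated. A sketch of a caller that benefits is shown below; the 4-argument llama_batch_get_one signature is assumed from the llama.h API of roughly this revision, sampling is elided, and error handling is ignored. Setting GGML_DISABLE_GRAPH_CACHING in the environment opts out entirely.

// Sketch of a decode loop that benefits from the cached graph.
#include "llama.h"
#include <vector>

static void generate(llama_context * ctx, std::vector<llama_token> prompt, int n_predict) {
    // prompt evaluation: more than one token per ubatch, so the graph is rebuilt and
    // caching stays disabled for this call
    llama_decode(ctx, llama_batch_get_one(prompt.data(), (int32_t) prompt.size(), 0, 0));

    llama_token tok = 0; // placeholder: pick the next token from llama_get_logits(ctx)
    for (int i = 0; i < n_predict; i++) {
        // single-token decode: from the second iteration onwards the cached graph can be
        // replayed, with only the KV-cache pointers and inputs updated
        llama_decode(ctx, llama_batch_get_one(&tok, 1, (llama_pos) (prompt.size() + i), 0));
    }
}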