@@ -2649,6 +2649,17 @@ struct llama_model {
    }
};

+// Object used to allow caching of the GGML graph between tokens where possible.
+struct ggml_cached_graph {
+    bool is_active = false;          // whether the graph built for the previous token may be reused
+    ggml_cgraph * gf;                // the cached graph
+    size_t n;                        // kv_self.n at the time the graph was built
+    ggml_backend_t backend_res;      // backend that holds the logits tensor (res)
+    ggml_backend_t backend_embd;     // backend that holds the embeddings tensor (embd)
+    struct ggml_tensor * res;        // output logits tensor
+    struct ggml_tensor * embd;       // output embeddings tensor
+};
+
struct llama_context {
    llama_context(const llama_model & model)
        : model(model)
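
For orientation, the decode path changed further down drives this struct roughly as follows. This is an illustrative sketch distilled from the later hunks, not code from the patch; the scheduler-side helpers ggml_use_cached_graph()/ggml_set_cached_graph() are used by this change but defined outside the hunks shown here, and caching_is_safe() is a hypothetical stand-in for the inline checks made below.

    // per ubatch, inside llama_decode_internal (simplified sketch)
    ggml_cgraph * gf;
    if (!ggml_use_cached_graph(lctx.sched) || lctx.cached_graph.n != kv_self.n) {
        gf = llama_build_graph(lctx, u_batch, false);                    // rebuild the graph
        ggml_set_cached_graph(lctx.sched, lctx.cached_graph.is_active);  // tell the scheduler whether reuse was allowed last time
        lctx.cached_graph.is_active = caching_is_safe(lctx, gf);         // hypothetical predicate: env var, device count, n_splits, batch size
        ggml_backend_sched_alloc_graph(lctx.sched, gf);
    } else {
        gf = lctx.cached_graph.gf;  // reuse the cached graph; only the KV view data pointers are patched
    }
    lctx.cached_graph.gf = gf;
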
@@ -2749,6 +2760,8 @@ struct llama_context {
    struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
    struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
    struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+
+    struct ggml_cached_graph cached_graph;
};

struct llama_lora_weight {
@@ -7886,7 +7899,6 @@ static void llm_build_kv_store(
        v_cur = ggml_transpose(ctx, v_cur);
    }
    cb(v_cache_view, "v_cache_view", il);
-
    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
}

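
One detail worth noting for the cached-graph update later in this diff: the cb() callback used by the graph builder typically appends the layer index to the tensor name, along these lines (simplified, following llama.cpp's llm_build_cb):

    // inside the graph-build callback; il is the layer index
    if (il >= 0) {
        ggml_format_name(cur, "%s-%d", name, il);   // e.g. "v_cache_view" becomes "v_cache_view-7"
    }

so the KV copy destinations end up named "k_cache_view-<il>" and "v_cache_view-<il>", which is exactly the format that llama_decode_internal below parses back with strncmp() and atoi().
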
@@ -14695,12 +14707,44 @@ static int llama_decode_internal(
        ggml_backend_sched_reset(lctx.sched);
        ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);

-       ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
-
+       ggml_cgraph * gf;
        // the output is always the last tensor in the graph
-       struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
-       struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+       struct ggml_tensor * res;
+       struct ggml_tensor * embd;
+
+       // the cached graph is only valid for the kv_self.n it was built with
+       bool n_has_changed_since_last_token = false;
+       if (lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
+       lctx.cached_graph.n = kv_self.n;
+
+       // Re-build the graph only if the cached graph cannot be reused
+       if (!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
+
+           gf = llama_build_graph(lctx, u_batch, false);
+
+           // Tell the GGML scheduler whether graph caching is in use, based on
+           // whether caching was activated here during the previous token
+           ggml_set_cached_graph(lctx.sched, lctx.cached_graph.is_active);
+
+           // Disable future graph caching if the GGML_DISABLE_GRAPH_CACHING env var is set,
+           // if there are multiple devices, if the batch size is greater than 1,
+           // or if the number of splits is not 2.
+           // TODO: enable graph caching for these cases
+           bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+               || (llama_get_device_count(model) > 1)
+               || (ggml_backend_sched_get_n_splits(lctx.sched) != 2);
+           // batch-size check: a GGML_OP_ADD whose second operand has more than one row
+           // implies more than one token in this ubatch
+           for (int i = 0; i < gf->n_nodes; i++) {
+               if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+                   disable_cached_ggml_graph = true;
+                   break;
+               }
+           }
+
+           // Set whether graph caching should be used for future tokens
+           lctx.cached_graph.is_active = !disable_cached_ggml_graph;

+           // the output is always the last tensor in the graph
+           res = gf->nodes[gf->n_nodes - 1];
+           embd = gf->nodes[gf->n_nodes - 2];
        if (lctx.n_outputs == 0) {
            // no output
            res = nullptr;
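
Two of the guards above deserve a brief note. The kv_self.n comparison is needed because the built graph bakes in KV views and masks whose shapes depend on n, so a change in n invalidates the cached graph. The node scan is a heuristic for detecting a ubatch with more than one token: per-token activations have n_tokens rows, so any GGML_OP_ADD whose second operand has more than one row implies batch size > 1, while a plain bias add (single row, broadcast) does not trigger it. A standalone sketch of that predicate, for illustration only (assumes ggml.h; not part of the patch):

    // returns true if the graph appears to have been built for more than one token,
    // using the same GGML_OP_ADD heuristic as the loop above
    static bool graph_has_multi_token_ubatch(const struct ggml_cgraph * gf) {
        for (int i = 0; i < gf->n_nodes; i++) {
            const struct ggml_tensor * node = gf->nodes[i];
            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
                return true;
            }
        }
        return false;
    }
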
@@ -14716,10 +14760,60 @@ static int llama_decode_internal(
            embd = nullptr; // do not extract embeddings when not needed
            GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
        }
+       lctx.cached_graph.res  = res;
+       lctx.cached_graph.embd = embd;
        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);

        ggml_backend_sched_alloc_graph(lctx.sched, gf);

+       }
+       else {
+           gf   = lctx.cached_graph.gf;
+           res  = lctx.cached_graph.res;
+           embd = lctx.cached_graph.embd;
+       }
+       lctx.cached_graph.gf = gf;
+
+       // Update the data pointers of the K and V cache views in the cached graph,
+       // so that this token's KV entries are written to the slot at kv_head.
+       if (gf != nullptr && gf->nodes != nullptr && ggml_use_cached_graph(lctx.sched)) {
+
+           const struct llama_hparams & hparams = model.hparams;
+           const int64_t kv_head = kv_self.head;
+
+           for (int i = 0; i < gf->n_nodes; i++) {
+               ggml_tensor * node = gf->nodes[i];
+               if (node->op == GGML_OP_CPY) {
+
+                   // K cache
+                   const char * k_prefix = "k_cache_view-";
+                   if (strncmp(node->src[1]->name, k_prefix, strlen(k_prefix)) == 0) {
+                       int il = atoi(node->src[1]->name + strlen(k_prefix)); // layer index from the tensor name
+                       const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+                       ggml_tensor * tmp_tensor = kv_self.k_l[il];
+                       size_t tmp_offset = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head;
+                       node->src[1]->data = static_cast<char *>(tmp_tensor->data) + tmp_offset;
+                   }
+
+                   // V cache
+                   const char * v_prefix = "v_cache_view-";
+                   if (strncmp(node->src[1]->name, v_prefix, strlen(v_prefix)) == 0) {
+                       int il = atoi(node->src[1]->name + strlen(v_prefix)); // layer index from the tensor name
+                       const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+                       ggml_tensor * tmp_tensor = kv_self.v_l[il];
+                       size_t tmp_offset;
+                       if (cparams.flash_attn) {
+                           tmp_offset = kv_head*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                       } else {
+                           // the V cache is transposed when not using flash attention
+                           tmp_offset = kv_head*ggml_element_size(kv_self.v_l[il]);
+                       }
+                       node->src[1]->data = static_cast<char *>(tmp_tensor->data) + tmp_offset;
+                   }
+
+               }
+           }
+
+       }
+
        llama_set_inputs(lctx, u_batch);

        llama_graph_compute(lctx, gf, n_threads);
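
To make the offset arithmetic above concrete, here is a small worked example with made-up sizes. It assumes the usual llama.cpp KV-cache layout: K (and V when flash attention is enabled) stores one contiguous row of n_embd_*_gqa elements per cache slot, while V without flash attention is stored transposed (see the ggml_transpose() in llm_build_kv_store earlier in this diff), so consecutive slots are adjacent elements within a row.

    // Illustrative numbers only: F16 cache (2 bytes/element), n_embd_k_gqa = n_embd_v_gqa = 1024, kv_head = 37
    //
    // K cache (and V cache with flash attention): skip kv_head full rows
    //   offset = ggml_row_size(GGML_TYPE_F16, 1024) * 37   // = 2048 * 37 = 75776 bytes
    //
    // V cache without flash attention (transposed layout): skip kv_head elements within the row
    //   offset = ggml_element_size(v_l[il]) * 37            // = 2 * 37 = 74 bytes
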
@@ -14742,11 +14836,15 @@ static int llama_decode_internal(
        // extract logits
        if (res) {
            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
-           GGML_ASSERT(backend_res != nullptr);
-           GGML_ASSERT(lctx.logits != nullptr);
-
            float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
            const int32_t n_outputs_new = lctx.n_outputs;
+           // When reusing a cached graph the scheduler has not re-assigned tensors,
+           // so fall back to the backend recorded when the graph was first built.
+           if (!ggml_use_cached_graph(lctx.sched))
+               lctx.cached_graph.backend_res = backend_res;
+           else
+               backend_res = lctx.cached_graph.backend_res;
+
+           GGML_ASSERT(backend_res != nullptr);
+           GGML_ASSERT(lctx.logits != nullptr);

            if (n_outputs_new) {
                GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
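
The same record-or-restore logic is applied to the embeddings backend in the next hunk. A small helper expressing the pattern, purely as an illustration (the patch itself keeps the two copies inline):

    // Hypothetical helper, not part of the patch: on a freshly built graph, record the
    // scheduler's backend assignment; on a cached graph, restore the recorded one.
    static ggml_backend_t cache_or_restore_backend(
            ggml_backend_sched_t sched,
            ggml_backend_t     & cached,    // e.g. lctx.cached_graph.backend_res
            ggml_backend_t       current) { // value just returned by ggml_backend_sched_get_tensor_backend()
        if (!ggml_use_cached_graph(sched)) {
            cached = current;
        }
        return cached;
    }
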
@@ -14758,6 +14856,12 @@ static int llama_decode_internal(
        // extract embeddings
        if (embd) {
            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+
+           // Same as for the logits backend above: with a cached graph, reuse the recorded backend.
+           if (!ggml_use_cached_graph(lctx.sched))
+               lctx.cached_graph.backend_embd = backend_embd;
+           else
+               backend_embd = lctx.cached_graph.backend_embd;
            GGML_ASSERT(backend_embd != nullptr);

            switch (cparams.pooling_type) {