@@ -2745,6 +2745,17 @@ struct llama_model {
     }
 };
 
+// Object used to allow caching of GGML graph between tokens where possible.
+struct ggml_cached_graph {
+    bool is_active = false;
+    ggml_cgraph * gf;
+    size_t n;
+    ggml_backend_t backend_res;
+    ggml_backend_t backend_embd;
+    struct ggml_tensor * res;
+    struct ggml_tensor * embd;
+};
+
 struct llama_context {
     llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
     ~llama_context() {
@@ -2846,6 +2857,8 @@ struct llama_context {
 
     // control vectors
     struct llama_control_vector cvec;
+
+    struct ggml_cached_graph cached_graph;
 };
 
 static size_t llama_get_device_count(const llama_model & model) {
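
The two hunks above only declare the cache object and hang it off llama_context; the decode-time logic that uses it follows in the next hunk. As a rough, self-contained illustration of the lifecycle those fields enable, the toy program below rebuilds a "graph" only when caching was not active for the previous token or when the cached kv_self.n changed. All names in it are invented for the sketch; the integer graph id and build_graph() merely stand in for ggml_cgraph and llama_build_graph().

#include <cstdio>

// Toy stand-in for ggml_cached_graph: the graph handle is just an integer id here.
struct toy_cached_graph {
    bool   is_active = false;   // was caching judged safe on the previous token?
    int    graph_id  = -1;      // stands in for ggml_cgraph * gf
    size_t n         = 0;       // the kv_self.n value the cached graph was built for
};

static int n_builds = 0;
static int build_graph() { return ++n_builds; }  // stands in for llama_build_graph()

int main() {
    toy_cached_graph cg;
    const size_t kv_n[] = {32, 32, 32, 64, 64};  // pretend kv_self.n per token

    for (size_t t = 0; t < 5; ++t) {
        const bool n_changed = (cg.n != kv_n[t]);
        cg.n = kv_n[t];
        if (!cg.is_active || n_changed) {
            cg.graph_id  = build_graph();
            cg.is_active = true;  // pretend none of the disabling conditions apply
        }
        printf("token %zu uses graph %d\n", t, cg.graph_id);
    }
    printf("graphs built: %d for 5 tokens\n", n_builds);
    return 0;
}

With the made-up kv_self.n sequence above, only two graphs are built for five tokens, which is the saving the patch is after.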
@@ -14668,12 +14681,42 @@ static int llama_decode_internal(
     ggml_backend_sched_reset(lctx.sched);
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-    ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
-
+    ggml_cgraph * gf;
     // the output is always the last tensor in the graph
-    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
-    struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+    struct ggml_tensor * res;
+    struct ggml_tensor * embd;
+
+    bool n_has_changed_since_last_token = false;
+    if (lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
+    lctx.cached_graph.n = kv_self.n;
+
+    // Re-build the graph only if graph caching is not possible
+    if (!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
+
+    gf = llama_build_graph(lctx, u_batch, false);
+
+    // Set whether GGML graph caching is in use within the GGML module, based on
+    // whether caching was activated here during the previous token
+    ggml_set_cached_graph(lctx.sched, lctx.cached_graph.is_active);
+
+    // Disable future graph caching in the presence of the env var,
+    // if there are multiple devices, or if the batch size is greater than 1
+    // TODO: enable graph caching for these cases
+    bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+        || (llama_get_device_count(model) > 1);
+    for (int i = 0; i < gf->n_nodes; i++) {
+        if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+            disable_cached_ggml_graph = true;
+            break;
+        }
+    }
+
+    // Set whether graph caching should be used for future tokens
+    lctx.cached_graph.is_active = !disable_cached_ggml_graph;
 
+    // the output is always the last tensor in the graph
+    res  = gf->nodes[gf->n_nodes - 1];
+    embd = gf->nodes[gf->n_nodes - 2];
     if (lctx.n_outputs == 0) {
         // no output
         res = nullptr;
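
The batch-size check in the hunk above is indirect: rather than consulting the batch directly, it scans the freshly built graph for an ADD node whose second operand spans more than one row, which the patch treats as a sign that more than one token is being processed, and therefore disables caching. Below is a sketch of that heuristic pulled out into a helper. The helper name is hypothetical; the only assumptions are the ggml_cgraph and ggml_tensor fields the diff itself already dereferences (gf->n_nodes, gf->nodes, node->op, node->src, node->ne), which come from ggml.h in this codebase.

#include "ggml.h"

// Hypothetical helper restating the loop above: returns true if any ADD node in the
// graph has a second source with more than one row, i.e. a multi-token batch.
static bool graph_looks_multi_token(const ggml_cgraph * gf) {
    for (int i = 0; i < gf->n_nodes; i++) {
        const ggml_tensor * node = gf->nodes[i];
        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
            return true;
        }
    }
    return false;
}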
@@ -14689,10 +14732,71 @@ static int llama_decode_internal(
         embd = nullptr; // do not extract embeddings when not needed
         GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
     }
+    lctx.cached_graph.res  = res;
+    lctx.cached_graph.embd = embd;
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
     ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
+    }
+    else {
+        gf   = lctx.cached_graph.gf;
+        res  = lctx.cached_graph.res;
+        embd = lctx.cached_graph.embd;
+    }
+    lctx.cached_graph.gf = gf;
+
+    if (ggml_use_cached_graph(lctx.sched)) {
+
+        // If using flash attention, find the mask node so it can be skipped when updating
+        // KV cache parameters in the cached graph nodes below
+        void * flash_attn_mask_node = nullptr;
+        if (cparams.flash_attn) {
+            for (int i = 0; i < gf->n_nodes; i++) {
+                ggml_tensor * node = gf->nodes[i];
+                if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                    flash_attn_mask_node = node->src[3];
+                    break;
+                }
+            }
+        }
+
+        // Temporarily store the KV cache parameters that will need to be updated in the cached graph.
+        const struct llama_hparams & hparams = model.hparams;
+        const int64_t n_layer = hparams.n_layer;
+        const int64_t kv_head = kv_self.head;
+        std::vector<void *> kv_cache_ptrs;
+        for (int il = 0; il < n_layer; ++il) {
+            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+            ggml_tensor * tmp_tensor = kv_self.k_l[il];
+            size_t tmp_offset = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head;
+            kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+            tmp_tensor = kv_self.v_l[il];
+            if (cparams.flash_attn) {
+                tmp_offset = kv_head*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+            } else {
+                tmp_offset = kv_head*ggml_element_size(kv_self.v_l[il]);
+            }
+            kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+        }
+
+        // Update the KV cache parameters in the cached graph.
+        int copy_op_count = 0;
+        if (gf != nullptr && gf->nodes != nullptr) {
+            for (int i = 0; i < gf->n_nodes; i++) {
+                ggml_tensor * node = gf->nodes[i];
+                if (node->op == GGML_OP_CPY) {
+                    if (node != flash_attn_mask_node) {
+                        node->src[1]->data = kv_cache_ptrs[copy_op_count];
+                        copy_op_count++;
+                    }
+                }
+            }
+        }
+
+    }
+
     llama_set_inputs(lctx, u_batch);
 
     llama_graph_compute(lctx, gf, n_threads);
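
The offset arithmetic in the hunk above decides where in each layer's K and V buffers the current token's data is copied, and it is the part that has to be refreshed every time the cached graph is reused. A standalone toy calculation with made-up sizes (fp16 elements, n_embd_k_gqa = n_embd_v_gqa = 4096, kv_head = 17; none of these numbers come from the patch) shows the three cases the loop computes:

#include <cstdio>
#include <cstddef>

int main() {
    // Made-up parameters for illustration only.
    const size_t elt_size     = 2;     // bytes per fp16 cache element
    const size_t n_embd_k_gqa = 4096;  // K row length per token
    const size_t n_embd_v_gqa = 4096;  // V row length per token
    const size_t kv_head      = 17;    // slot where the new token's KV data lands

    // K cache: one contiguous row per token, so skip kv_head full rows.
    const size_t k_offset = elt_size * n_embd_k_gqa * kv_head;

    // V cache with flash attention: laid out like K, one row per token.
    const size_t v_offset_fa = elt_size * n_embd_v_gqa * kv_head;

    // V cache without flash attention: stored transposed, so consecutive tokens
    // sit one element apart and the offset is just kv_head elements.
    const size_t v_offset_no_fa = elt_size * kv_head;

    printf("k: %zu bytes, v (flash attn): %zu bytes, v (default): %zu bytes\n",
           k_offset, v_offset_fa, v_offset_no_fa);
    return 0;
}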
@@ -14715,11 +14819,15 @@ static int llama_decode_internal(
     // extract logits
     if (res) {
         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
-        GGML_ASSERT(backend_res != nullptr);
-        GGML_ASSERT(lctx.logits != nullptr);
-
         float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
         const int32_t n_outputs_new = lctx.n_outputs;
+        if (!ggml_use_cached_graph(lctx.sched))
+            lctx.cached_graph.backend_res = backend_res;
+        else
+            backend_res = lctx.cached_graph.backend_res;
+
+        GGML_ASSERT(backend_res != nullptr);
+        GGML_ASSERT(lctx.logits != nullptr);
 
         if (n_outputs_new) {
             GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
@@ -14731,6 +14839,12 @@ static int llama_decode_internal(
     // extract embeddings
     if (embd) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+
+
+        if (!ggml_use_cached_graph(lctx.sched))
+            lctx.cached_graph.backend_embd = backend_embd;
+        else
+            backend_embd = lctx.cached_graph.backend_embd;
         GGML_ASSERT(backend_embd != nullptr);
 
         switch (cparams.pooling_type) {
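
The last two hunks apply the same pattern to backend_res and backend_embd: on a pass that just built and scheduled the graph, the backend returned by ggml_backend_sched_get_tensor_backend() is recorded in the cache; on a pass that reuses the cached graph, the recorded value is used instead, and the null checks are moved after that choice, presumably because the scheduler lookup is only trustworthy right after the graph has been allocated on it. A minimal restatement of the pattern as a helper, with a hypothetical name that is not part of the patch (ggml_backend_t comes from ggml-backend.h):

#include "ggml-backend.h"

// Hypothetical helper: remember the scheduler's answer on a freshly scheduled graph,
// reuse the remembered answer when the cached graph is in use.
static ggml_backend_t cached_or_fresh_backend(bool graph_is_cached,
                                              ggml_backend_t fresh_lookup,
                                              ggml_backend_t & remembered) {
    if (!graph_is_cached) {
        remembered = fresh_lookup;   // first pass: store for later tokens
    }
    return remembered;               // cached pass: fresh_lookup is ignored
}

With such a helper, the logits path would read something like backend_res = cached_or_fresh_backend(ggml_use_cached_graph(lctx.sched), backend_res, lctx.cached_graph.backend_res); followed by the existing asserts, and the embeddings path would mirror it.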