@@ -2712,6 +2712,16 @@ struct llama_model {
}
};

+ // Object used to allow caching of the GGML graph between tokens where possible.
+ struct ggml_cached_graph {
+ ggml_cgraph * gf;
+ size_t n;
+ ggml_backend_t backend_res;
+ ggml_backend_t backend_embd;
+ struct ggml_tensor * res;
+ struct ggml_tensor * embd;
+ };
+
struct llama_context {
llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
~llama_context() {
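The cache object above is read and written throughout llama_decode_internal further down in this diff. As a reading aid, here is the same struct with per-field annotations; the annotations are editorial and not part of the patch.

```cpp
#include "ggml.h"
#include "ggml-backend.h"

// Annotated copy of ggml_cached_graph above (same fields as the patch; comments added here).
struct ggml_cached_graph {
    ggml_cgraph  * gf;           // graph built for the previous token, reused verbatim when possible
    size_t         n;            // kv_self.n at build time; a change forces a rebuild
    ggml_backend_t backend_res;  // backend that held the logits tensor on the non-cached pass
    ggml_backend_t backend_embd; // backend that held the embeddings tensor on the non-cached pass
    struct ggml_tensor * res;    // cached pointer to the "result_output" tensor
    struct ggml_tensor * embd;   // cached pointer to the embeddings tensor
};
```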
@@ -2813,6 +2823,8 @@ struct llama_context {

// control vectors
struct llama_control_vector cvec;
+
+ struct ggml_cached_graph cached_graph;
};

static size_t llama_get_device_count(const llama_model & model) {
@@ -14524,12 +14536,37 @@ static int llama_decode_internal(
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);

- ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
-
+ ggml_cgraph * gf;
// the output is always the last tensor in the graph
- struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
- struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+ struct ggml_tensor * res;
+ struct ggml_tensor * embd;
+
+ bool n_has_changed_since_last_token = false;
+ if (lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
+ lctx.cached_graph.n = kv_self.n;
+
+ // Re-build the graph only if graph caching is not possible
+ if (!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
+
+ gf = llama_build_graph(lctx, u_batch, false);
+
+ // disable future graph caching in presence of the env var,
+ // if there are multiple devices, or if the batch size is greater than 1
+ // TODO: enable graph caching for these cases
+ bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+ || (llama_get_device_count(model) > 1);
+ for (int i = 0; i < gf->n_nodes; i++) {
+ if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+ disable_cached_ggml_graph = true;
+ break;
+ }
+ }
+
+ if (!disable_cached_ggml_graph) ggml_set_cached_graph(lctx.sched, true);

+ // the output is always the last tensor in the graph
+ res = gf->nodes[gf->n_nodes - 1];
+ embd = gf->nodes[gf->n_nodes - 2];
if (lctx.n_outputs == 0) {
// no output
res = nullptr;
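The eligibility test above is easier to follow when restated as a standalone predicate. The sketch below is illustrative rather than part of the patch: the function name is invented, and the GGML_OP_ADD scan mirrors the loop above, which treats an add whose second source has ne[1] > 1 as a proxy for a batch of more than one token.

```cpp
#include <cstdlib>
#include "ggml.h"

// Illustrative sketch: the conditions under which this change allows graph caching.
static bool graph_is_cacheable(const struct ggml_cgraph * gf, int device_count) {
    if (std::getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr) {
        return false; // explicit opt-out via environment variable
    }
    if (device_count > 1) {
        return false; // multi-device runs are not cached yet (see the TODO above)
    }
    for (int i = 0; i < gf->n_nodes; i++) {
        const struct ggml_tensor * node = gf->nodes[i];
        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
            return false; // batch size greater than 1, not cached yet
        }
    }
    return true;
}
```

Setting the GGML_DISABLE_GRAPH_CACHING environment variable to any value therefore disables the optimization entirely.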
@@ -14545,10 +14582,71 @@ static int llama_decode_internal(
embd = nullptr; // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
}
+ lctx.cached_graph.res = res;
+ lctx.cached_graph.embd = embd;
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);

ggml_backend_sched_alloc_graph(lctx.sched, gf);

+ }
+ else {
+ gf = lctx.cached_graph.gf;
+ res = lctx.cached_graph.res;
+ embd = lctx.cached_graph.embd;
+ }
+ lctx.cached_graph.gf = gf;
+
+ if (ggml_use_cached_graph(lctx.sched)) {
+
+ // If using flash attention, find the mask node so it can be skipped when updating
+ // KV cache parameters in the cached graph nodes below
+ void * flash_attn_mask_node = nullptr;
+ if (cparams.flash_attn) {
+ for (int i = 0; i < gf->n_nodes; i++) {
+ ggml_tensor * node = gf->nodes[i];
+ if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+ flash_attn_mask_node = node->src[3];
+ break;
+ }
+ }
+ }
+
+ // Temporarily store the KV cache parameters that will need to be updated in the cached graph.
+ const struct llama_hparams & hparams = model.hparams;
+ const int64_t n_layer = hparams.n_layer;
+ const int64_t kv_head = kv_self.head;
+ std::vector<void *> kv_cache_ptrs;
+ for (int il = 0; il < n_layer; ++il) {
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+ ggml_tensor * tmp_tensor = kv_self.k_l[il];
+ size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
+ kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+ tmp_tensor = kv_self.v_l[il];
+ if (cparams.flash_attn) {
+ tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+ } else {
+ tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
+ }
+ kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+ }
+
+ // Update the KV cache parameters in the cached graph.
+ int copy_op_count = 0;
+ if (gf != nullptr && gf->nodes != nullptr) {
+ for (int i = 0; i < gf->n_nodes; i++) {
+ ggml_tensor * node = gf->nodes[i];
+ if (node->op == GGML_OP_CPY) {
+ if (node != flash_attn_mask_node) {
+ node->src[1]->data = kv_cache_ptrs[copy_op_count];
+ copy_op_count++;
+ }
+ }
+ }
+ }
+
+ }
+
llama_set_inputs(lctx, u_batch);

llama_graph_compute(lctx, gf, n_threads);
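The offset arithmetic in the pointer-gathering loop is the densest part of the change. The helpers below restate it for one layer; the names are invented for illustration, and the comments follow the two branches above: with flash attention the V cache is addressed by whole rows per KV slot, like K, while without it the V cache is stored transposed, so consecutive KV slots are one element apart.

```cpp
#include <cstdint>
#include <cstddef>
#include "ggml.h"

// Illustrative helpers (not part of the patch): byte offset of the first destination
// slot for one layer's K and V copies, matching the expressions in the loop above.
static size_t k_cache_offset(const struct ggml_tensor * k_l, int64_t n_embd_k_gqa, int64_t kv_head) {
    // K is row-major: skip kv_head whole rows of n_embd_k_gqa elements.
    return ggml_row_size(k_l->type, n_embd_k_gqa) * kv_head;
}

static size_t v_cache_offset(const struct ggml_tensor * v_l, int64_t n_embd_v_gqa, int64_t kv_head, bool flash_attn) {
    return flash_attn
        ? ggml_row_size(v_l->type, n_embd_v_gqa) * kv_head // V laid out like K
        : ggml_element_size(v_l) * kv_head;                // V transposed: one element per KV slot
}
```

Each layer contributes one K and one V destination pointer to kv_cache_ptrs, and the update loop then walks the cached graph's GGML_OP_CPY nodes in order, rewriting each copy's destination data pointer (skipping the flash-attention mask cast) so that the reused graph writes the new token's KV data into the correct slots.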
@@ -14571,11 +14669,15 @@ static int llama_decode_internal(
// extract logits
if (res) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
- GGML_ASSERT(backend_res != nullptr);
- GGML_ASSERT(lctx.logits != nullptr);
-
float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
const int32_t n_outputs_new = lctx.n_outputs;
+ if (!ggml_use_cached_graph(lctx.sched))
+ lctx.cached_graph.backend_res = backend_res;
+ else
+ backend_res = lctx.cached_graph.backend_res;
+
+ GGML_ASSERT(backend_res != nullptr);
+ GGML_ASSERT(lctx.logits != nullptr);

if (n_outputs_new) {
GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
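The reordered asserts above implement a store-or-reuse pattern for the output backend, and the embeddings hunk below repeats it for backend_embd: on a non-cached pass the backend reported by the scheduler is recorded in the cache object, and on cached passes the recorded value is used instead. A minimal sketch of the pattern, with an invented helper name:

```cpp
#include "ggml-backend.h"

// Illustrative only: record the backend on the non-cached pass,
// reuse the recorded value when the cached graph is replayed.
static ggml_backend_t cached_or_fresh(bool graph_is_cached, ggml_backend_t fresh, ggml_backend_t & slot) {
    if (!graph_is_cached) {
        slot = fresh; // store the backend the scheduler reported for this tensor
    }
    return slot;      // on cached passes, use the stored backend
}
```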
@@ -14587,6 +14689,12 @@ static int llama_decode_internal(
// extract embeddings
if (embd) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+
+ if (!ggml_use_cached_graph(lctx.sched))
+ lctx.cached_graph.backend_embd = backend_embd;
+ else
+ backend_embd = lctx.cached_graph.backend_embd;
GGML_ASSERT(backend_embd != nullptr);

switch (cparams.pooling_type) {