Commit 78deb50
Avoid rebuild of GGML graph for each token
Introduces caching of the GGML graph to avoid an unnecessary full rebuild between tokens. KV cache parameters, which change with each token, are updated directly in the cached GGML graph. Caching can be disabled with the GGML_DISABLE_GRAPH_CACHING environment variable. Refs ggml-org#7456
1 parent 470939d commit 78deb50
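
The new scheduler calls introduced below, ggml_use_cached_graph() and ggml_set_cached_graph(), are meant to be driven from the application's decoding loop. The following is a minimal sketch of that intended pattern, not code from this commit: build_graph(), shape_changed, and compute_next_token() are hypothetical placeholders for the application's own graph construction and change detection. Setting GGML_DISABLE_GRAPH_CACHING in the environment keeps the previous rebuild-every-token behaviour.

    #include "ggml-backend.h"

    // hypothetical placeholder for the application's own graph construction
    // (in llama.cpp this role is played by llama_build_graph)
    struct ggml_cgraph * build_graph(void);

    static enum ggml_status compute_next_token(ggml_backend_sched_t sched,
                                               struct ggml_cgraph ** cached_gf,
                                               bool shape_changed) {
        if (!ggml_use_cached_graph(sched) || shape_changed) {
            // full rebuild and allocation only when the cached graph cannot be reused
            *cached_gf = build_graph();
            ggml_backend_sched_alloc_graph(sched, *cached_gf);
            ggml_set_cached_graph(sched, true);
        }
        // on the cached path, only per-token parameters (e.g. KV cache write offsets)
        // are patched into *cached_gf before it is computed again
        return ggml_backend_sched_graph_compute_async(sched, *cached_gf);
    }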

File tree

3 files changed: +152, -8 lines

ggml/include/ggml-backend.h

Lines changed: 5 additions & 0 deletions
@@ -230,6 +230,11 @@ extern "C" {
     GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
     GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
 
+    // Utility to query whether cached GGML graph is in use
+    GGML_API bool ggml_use_cached_graph(ggml_backend_sched_t sched);
+
+    // Set whether or not to use GGML graph caching
+    GGML_API void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value);
 
 #ifdef __cplusplus
 }

ggml/src/ggml-backend.c

Lines changed: 32 additions & 1 deletion
@@ -1036,6 +1036,13 @@ struct ggml_backend_sched_split {
     struct ggml_cgraph graph;
 };
 
+// Object to facilitate GGML graph caching
+struct ggml_cached_graph {
+    bool is_active;
+    ggml_backend_t input_backend;
+    struct ggml_tensor * input_cpy[GGML_SCHED_MAX_SPLIT_INPUTS];
+};
+
 struct ggml_backend_sched {
     bool is_reset; // true if the scheduler has been reset since the last graph split
     bool is_alloc;
@@ -1087,6 +1094,8 @@ struct ggml_backend_sched {
     __attribute__((aligned(GGML_MEM_ALIGN)))
 #endif
     char context_buffer[GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
+
+    struct ggml_cached_graph cached_graph;
 };
 
 #define hash_id(tensor) ggml_hash_find_or_insert(sched->hash_set, tensor)
@@ -1753,6 +1762,14 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
            struct ggml_tensor * input = split->inputs[j];
            struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split_backend_id][sched->cur_copy];
 
+            if (!sched->cached_graph.is_active) {
+                sched->cached_graph.input_backend = input_backend;
+                sched->cached_graph.input_cpy[j] = input_cpy;
+            }
+            else {
+                input_backend = sched->cached_graph.input_backend;
+                input_cpy = sched->cached_graph.input_cpy[j];
+            }
            if (input->flags & GGML_TENSOR_FLAG_INPUT) {
                // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
                if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
@@ -1872,6 +1889,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
 
    ggml_backend_sched_reset(sched);
 
+    sched->cached_graph.is_active = false;
+
    return sched;
 }
 
@@ -1947,6 +1966,9 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st
 }
 
 enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+
+    if(!sched->cached_graph.is_active)
+    {
    if (!sched->is_reset && !sched->is_alloc) {
        ggml_backend_sched_reset(sched);
    }
@@ -1956,7 +1978,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sch
            return GGML_STATUS_ALLOC_FAILED;
        }
    }
-
+    }
    return ggml_backend_sched_compute_splits(sched);
 }
 
@@ -2223,3 +2245,12 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
 
    return true;
 }
+
+bool ggml_use_cached_graph(ggml_backend_sched_t sched) {
+    return sched->cached_graph.is_active;
+}
+
+void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value) {
+    sched->cached_graph.is_active = set_value;
+}
+
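
As a quick sanity check of the new accessors and of the default set in ggml_backend_sched_new() above, a standalone snippet along these lines should hold (hypothetical test code, not part of this commit; it assumes a single CPU backend and default buffer types):

    #include "ggml-backend.h"
    #include <assert.h>

    int main(void) {
        // scheduler over one CPU backend, default buffer types, no parallel copies
        ggml_backend_t cpu = ggml_backend_cpu_init();
        ggml_backend_sched_t sched = ggml_backend_sched_new(&cpu, NULL, 1, 2048, false);

        // a freshly created scheduler starts with graph caching disabled
        assert(!ggml_use_cached_graph(sched));

        // callers opt in once a graph has been built and allocated
        ggml_set_cached_graph(sched, true);
        assert(ggml_use_cached_graph(sched));

        ggml_backend_sched_free(sched);
        ggml_backend_free(cpu);
        return 0;
    }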

src/llama.cpp

Lines changed: 115 additions & 7 deletions
@@ -2712,6 +2712,16 @@ struct llama_model {
     }
 };
 
+// Object used to allow caching of GGML graph between tokens where possible.
+struct ggml_cached_graph {
+    ggml_cgraph * gf;
+    size_t n;
+    ggml_backend_t backend_res;
+    ggml_backend_t backend_embd;
+    struct ggml_tensor * res;
+    struct ggml_tensor * embd;
+};
+
 struct llama_context {
     llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
     ~llama_context() {
@@ -2813,6 +2823,8 @@ struct llama_context {
 
     // control vectors
     struct llama_control_vector cvec;
+
+    struct ggml_cached_graph cached_graph;
 };
 
 static size_t llama_get_device_count(const llama_model & model) {
@@ -14524,12 +14536,37 @@ static int llama_decode_internal(
        ggml_backend_sched_reset(lctx.sched);
        ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
-
+        ggml_cgraph * gf;
        // the output is always the last tensor in the graph
-        struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
-        struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+        struct ggml_tensor * res;
+        struct ggml_tensor * embd;
+
+        bool n_has_changed_since_last_token = false;
+        if(lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
+        lctx.cached_graph.n = kv_self.n;
+
+        // Re-build graph only if graph caching is not possible
+        if(!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
+
+            gf = llama_build_graph(lctx, u_batch, false);
+
+            // disable future graph caching in presence of env var,
+            // if there are multiple devices, or if batch size is greater than 1
+            // TODO: enable graph caching for these cases
+            bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+                || (llama_get_device_count(model) > 1);
+            for (int i = 0 ; i < gf->n_nodes; i++) {
+                if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+                    disable_cached_ggml_graph = true;
+                    break;
+                }
+            }
+
+            if(!disable_cached_ggml_graph) ggml_set_cached_graph(lctx.sched,true);
 
+            // the output is always the last tensor in the graph
+            res  = gf->nodes[gf->n_nodes - 1];
+            embd = gf->nodes[gf->n_nodes - 2];
        if (lctx.n_outputs == 0) {
            // no output
            res = nullptr;
@@ -14545,10 +14582,71 @@ static int llama_decode_internal(
            embd = nullptr; // do not extract embeddings when not needed
            GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
        }
+            lctx.cached_graph.res = res;
+            lctx.cached_graph.embd = embd;
        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
        ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
+        }
+        else {
+            gf = lctx.cached_graph.gf;
+            res = lctx.cached_graph.res;
+            embd = lctx.cached_graph.embd;
+        }
+        lctx.cached_graph.gf = gf;
+
+        if(ggml_use_cached_graph(lctx.sched)) {
+
+            // If using flash attention, find mask node so it can be skipped when updating
+            // KV cache parameters in cached graph nodes below
+            void * flash_attn_mask_node = nullptr;
+            if(cparams.flash_attn) {
+                for (int i = 0; i < gf->n_nodes; i++) {
+                    ggml_tensor * node = gf->nodes[i];
+                    if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                        flash_attn_mask_node = node->src[3];
+                        break;
+                    }
+                }
+            }
+
+            // Temporarily store KV cache parameters that will need to be updated in the cached graph.
+            const struct llama_hparams & hparams = model.hparams;
+            const int64_t n_layer = hparams.n_layer;
+            const int64_t kv_head = kv_self.head;
+            std::vector<void *> kv_cache_ptrs;
+            for (int il = 0; il < n_layer; ++il) {
+                const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+                const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+                ggml_tensor * tmp_tensor = kv_self.k_l[il];
+                size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
+                kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+                tmp_tensor = kv_self.v_l[il];
+                if (cparams.flash_attn) {
+                    tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                } else {
+                    tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
+                }
+                kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+            }
+
+            // Update KV cache parameters in cached graph.
+            int copy_op_count = 0;
+            if(gf != nullptr && gf->nodes != nullptr){
+                for (int i = 0; i < gf->n_nodes; i++) {
+                    ggml_tensor * node = gf->nodes[i];
+                    if (node->op == GGML_OP_CPY) {
+                        if (node != flash_attn_mask_node) {
+                            node->src[1]->data = kv_cache_ptrs[copy_op_count];
+                            copy_op_count++;
+                        }
+                    }
+                }
+            }
+
+        }
+
        llama_set_inputs(lctx, u_batch);
 
        llama_graph_compute(lctx, gf, n_threads);
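
For reference, the per-layer K and V write offsets collected into kv_cache_ptrs above reduce to the helper below. This is an illustrative sketch only, not code from the commit: kv_k and kv_v stand in for kv_self.k_l[il] and kv_self.v_l[il], and the ggml headers are assumed to be on the include path.

    #include "ggml.h"

    #include <cstdint>
    #include <utility>

    // destination pointers for the current token's K and V rows of one layer
    static std::pair<void *, void *> kv_write_ptrs(const ggml_tensor * kv_k, const ggml_tensor * kv_v,
                                                   int64_t n_embd_k_gqa, int64_t n_embd_v_gqa,
                                                   int64_t kv_head, bool flash_attn) {
        // K is stored token-major: skip kv_head rows of n_embd_k_gqa elements
        const size_t k_off = ggml_row_size(kv_k->type, n_embd_k_gqa)*kv_head;
        // V layout depends on flash attention: token-major rows with it,
        // transposed (element stride per token) without it
        const size_t v_off = flash_attn
            ? ggml_row_size(kv_v->type, n_embd_v_gqa)*kv_head
            : ggml_element_size(kv_v)*kv_head;
        return { (char *) kv_k->data + k_off, (char *) kv_v->data + v_off };
    }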
@@ -14571,11 +14669,15 @@ static int llama_decode_internal(
        // extract logits
        if (res) {
            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
-            GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(lctx.logits != nullptr);
-
            float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
            const int32_t n_outputs_new = lctx.n_outputs;
+            if(!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_res = backend_res;
+            else
+                backend_res = lctx.cached_graph.backend_res;
+
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(lctx.logits != nullptr);
 
            if (n_outputs_new) {
                GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
@@ -14587,6 +14689,12 @@ static int llama_decode_internal(
        // extract embeddings
        if (embd) {
            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+
+
+            if(!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_embd = backend_embd;
+            else
+                backend_embd = lctx.cached_graph.backend_embd;
            GGML_ASSERT(backend_embd != nullptr);
 
            switch (cparams.pooling_type) {
