
Commit b155949

agray3 authored and Nexesenex committed
ggml: avoid rebuild of GGML graph for each token (ggml-org#7456)
Introduces caching of the GGML graph to avoid an unnecessary full rebuild between each token. KV cache parameters, which change with each token, are updated directly in the cached GGML graph. Can be disabled with the GGML_DISABLE_GRAPH_CACHING environment variable.

fix seg fault
restrict to nsplit=2
Improve identification of K and V nodes for param updates
remove stale code
Reworked to directly update KV cache params using info from name
make n_embd_v_gqa_* dependent on layer
1 parent eeee35c commit b155949
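
The mechanism in brief: the GGML graph is rebuilt only when caching is not yet active or when the padded KV length kv_self.n has changed; otherwise the cached graph is reused and only the per-token KV-cache parameters are patched. A rough, self-contained illustration of that decision (hypothetical names and values, not code from this patch):

#include <cstdio>
#include <cstdlib>

// Mirrors the per-token bookkeeping kept in llama_context::cached_graph (simplified).
struct cached_graph_state {
    bool   is_active = false; // did the previous token decide that caching is safe?
    size_t n         = 0;     // padded KV length the cached graph was built for
};

static bool must_rebuild(const cached_graph_state & cg, size_t kv_n) {
    return !cg.is_active || cg.n != kv_n;
}

int main() {
    cached_graph_state cg;
    const bool caching_disabled = getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr;

    const size_t kv_n = 256; // hypothetical padded KV length, kept constant here for simplicity
    for (int token = 0; token < 4; ++token) {
        if (must_rebuild(cg, kv_n)) {
            printf("token %d: full graph rebuild\n", token);
            cg.n         = kv_n;
            cg.is_active = !caching_disabled; // opt in to reuse for the next token
        } else {
            printf("token %d: reuse cached graph, patch KV cache views only\n", token);
        }
    }
    return 0;
}

The actual patch additionally refuses to cache when more than one device is in use, when the scheduler produces a number of splits other than 2, or when the batch size is greater than 1, as the diff for src/llama.cpp below shows.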

File tree

3 files changed: +155 −15 lines changed

ggml/include/ggml-backend.h

Lines changed: 5 additions & 0 deletions
@@ -232,6 +232,11 @@ extern "C" {
     GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
     GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
 
+    // Utility to query whether cached GGML graph is in use
+    GGML_API bool ggml_use_cached_graph(ggml_backend_sched_t sched);
+
+    // Set whether or not to use GGML graph caching
+    GGML_API void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value);
 
 #ifdef __cplusplus
 }
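
The new public API added by this commit consists of these two functions plus the environment variable. A minimal usage sketch of the accessors (a hypothetical caller, not the llama.cpp integration further down, which also patches the KV-cache views and applies extra eligibility checks); the flag is only switched on after an uncached compute, so the scheduler has had a chance to record its split inputs:

#include <cstdlib>           // getenv
#include "ggml-backend.h"    // ggml_use_cached_graph / ggml_set_cached_graph (added by this commit)

// Hypothetical helper: run one scheduler compute, then opt in to graph caching.
static enum ggml_status compute_token(ggml_backend_sched_t sched, struct ggml_cgraph * gf) {
    const bool was_cached = ggml_use_cached_graph(sched);

    // With caching active the scheduler skips its reset/alloc path and reuses the
    // split inputs it recorded during the first (uncached) pass.
    enum ggml_status status = ggml_backend_sched_graph_compute_async(sched, gf);

    // Enable reuse for subsequent tokens unless the user disabled it via the env var.
    if (status == GGML_STATUS_SUCCESS && !was_cached &&
        getenv("GGML_DISABLE_GRAPH_CACHING") == nullptr) {
        ggml_set_cached_graph(sched, true);
    }
    return status;
}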

ggml/src/ggml-backend.c

Lines changed: 38 additions & 7 deletions
@@ -1045,6 +1045,13 @@ struct ggml_backend_sched_split {
     struct ggml_cgraph graph;
 };
 
+// Object to facilitate GGML graph caching
+struct ggml_cached_graph {
+    bool is_active;
+    ggml_backend_t input_backend;
+    struct ggml_tensor * input_cpy[GGML_SCHED_MAX_SPLIT_INPUTS];
+};
+
 struct ggml_backend_sched {
     bool is_reset; // true if the scheduler has been reset since the last graph split
     bool is_alloc;

@@ -1090,6 +1097,8 @@ struct ggml_backend_sched {
     size_t context_buffer_size;
 
     bool debug;
+
+    struct ggml_cached_graph cached_graph;
 };
 
 #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)

@@ -1767,6 +1776,14 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
             struct ggml_tensor * input = split->inputs[j];
             struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
 
+            if (!sched->cached_graph.is_active) {
+                sched->cached_graph.input_backend = input_backend;
+                sched->cached_graph.input_cpy[j] = input_cpy;
+            }
+            else {
+                input_backend = sched->cached_graph.input_backend;
+                input_cpy = sched->cached_graph.input_cpy[j];
+            }
             if (input->flags & GGML_TENSOR_FLAG_INPUT) {
                 // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
                 if (sched->events[split_backend_id][sched->cur_copy] != NULL) {

@@ -1888,6 +1905,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
 
     ggml_backend_sched_reset(sched);
 
+    sched->cached_graph.is_active = false;
+
     return sched;
 }
 

@@ -1964,16 +1983,19 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st
 }
 
 enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
-    if (!sched->is_reset && !sched->is_alloc) {
-        ggml_backend_sched_reset(sched);
-    }
 
-    if (!sched->is_alloc) {
-        if (!ggml_backend_sched_alloc_graph(sched, graph)) {
-            return GGML_STATUS_ALLOC_FAILED;
+    if(!sched->cached_graph.is_active)
+    {
+        if (!sched->is_reset && !sched->is_alloc) {
+            ggml_backend_sched_reset(sched);
         }
-    }
 
+        if (!sched->is_alloc) {
+            if (!ggml_backend_sched_alloc_graph(sched, graph)) {
+                return GGML_STATUS_ALLOC_FAILED;
+            }
+        }
+    }
     return ggml_backend_sched_compute_splits(sched);
 }
 

@@ -2238,3 +2260,12 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
 
     return true;
 }
+
+bool ggml_use_cached_graph(ggml_backend_sched_t sched) {
+    return sched->cached_graph.is_active;
+}
+
+void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value) {
+    sched->cached_graph.is_active = set_value;
+}
+
src/llama.cpp

Lines changed: 112 additions & 8 deletions
@@ -2649,6 +2649,17 @@ struct llama_model {
     }
 };
 
+// Object used to allow caching of GGML graph between tokens where possible.
+struct ggml_cached_graph {
+    bool is_active = false;
+    ggml_cgraph * gf;
+    size_t n;
+    ggml_backend_t backend_res;
+    ggml_backend_t backend_embd;
+    struct ggml_tensor * res;
+    struct ggml_tensor * embd;
+};
+
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)

@@ -2749,6 +2760,8 @@ struct llama_context {
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+
+    struct ggml_cached_graph cached_graph;
 };
 
 struct llama_lora_weight {

@@ -7886,7 +7899,6 @@ static void llm_build_kv_store(
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
-
     ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
 }
 

@@ -14695,12 +14707,44 @@ static int llama_decode_internal(
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
-
+        ggml_cgraph * gf;
         // the output is always the last tensor in the graph
-        struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
-        struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+        struct ggml_tensor * res;
+        struct ggml_tensor * embd;
+
+        bool n_has_changed_since_last_token = false;
+        if(lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
+        lctx.cached_graph.n = kv_self.n;
+
+        // Re-build graph only if graph caching is not possible
+        if(!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
+
+            gf = llama_build_graph(lctx, u_batch, false);
+
+            // Set whether GGML graph caching is in use within GGML module, based on
+            // whether caching was activated here during the previous token
+            ggml_set_cached_graph(lctx.sched,lctx.cached_graph.is_active);
+
+            // Disable future graph caching in presence of env var,
+            // if there are multiple devices, if batch size is greater than 1,
+            // or if nsplits is not 2.
+            // TO DO enable graph caching for these cases
+            bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+                || (llama_get_device_count(model) > 1)
+                || (ggml_backend_sched_get_n_splits(lctx.sched) != 2);
+            for (int i = 0 ; i < gf->n_nodes; i++) {
+                if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+                    disable_cached_ggml_graph = true;
+                    break;
+                }
+            }
+
+            // Set whether graph caching should be used for future tokens
+            lctx.cached_graph.is_active=!disable_cached_ggml_graph;
 
+            // the output is always the last tensor in the graph
+            res  = gf->nodes[gf->n_nodes - 1];
+            embd = gf->nodes[gf->n_nodes - 2];
         if (lctx.n_outputs == 0) {
             // no output
             res  = nullptr;

@@ -14716,10 +14760,60 @@ static int llama_decode_internal(
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+        lctx.cached_graph.res = res;
+        lctx.cached_graph.embd = embd;
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
+        }
+        else {
+            gf = lctx.cached_graph.gf;
+            res = lctx.cached_graph.res;
+            embd = lctx.cached_graph.embd;
+        }
+        lctx.cached_graph.gf = gf;
+
+        // Update K and V cache parameters in cached graph.
+        if(gf != nullptr && gf->nodes != nullptr && ggml_use_cached_graph(lctx.sched)) {
+
+            const struct llama_hparams & hparams = model.hparams;
+            const int64_t kv_head = kv_self.head;
+
+            for (int i = 0; i < gf->n_nodes; i++) {
+                ggml_tensor * node = gf->nodes[i];
+                if (node->op == GGML_OP_CPY) {
+
+                    // K cache
+                    const char* k_prefix = "k_cache_view-";
+                    if (strncmp(node->src[1]->name, k_prefix, strlen(k_prefix)) == 0) {
+                        int il = atoi(node->src[1]->name + strlen(k_prefix)); // Layer index from name
+                        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+                        ggml_tensor * tmp_tensor = kv_self.k_l[il];
+                        size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
+                        node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
+                    }
+
+                    // V cache
+                    const char* v_prefix = "v_cache_view-";
+                    if (strncmp(node->src[1]->name, v_prefix, strlen(v_prefix)) == 0) {
+                        int il = atoi(node->src[1]->name + strlen(v_prefix)); // Layer index from name
+                        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+                        ggml_tensor * tmp_tensor = kv_self.v_l[il];
+                        size_t tmp_offset;
+                        if (cparams.flash_attn) {
+                            tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                        } else {
+                            tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
+                        }
+                        node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
+                    }
+
+                }
+            }
+
+        }
+
         llama_set_inputs(lctx, u_batch);
 
         llama_graph_compute(lctx, gf, n_threads);

@@ -14742,11 +14836,15 @@ static int llama_decode_internal(
         // extract logits
         if (res) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
-            GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(lctx.logits != nullptr);
-
             float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
             const int32_t n_outputs_new = lctx.n_outputs;
+            if(!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_res = backend_res;
+            else
+                backend_res = lctx.cached_graph.backend_res;
+
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(lctx.logits != nullptr);
 
             if (n_outputs_new) {
                 GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);

@@ -14758,6 +14856,12 @@ static int llama_decode_internal(
         // extract embeddings
         if (embd) {
            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+
+
+            if(!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_embd = backend_embd;
+            else
+                backend_embd = lctx.cached_graph.backend_embd;
            GGML_ASSERT(backend_embd != nullptr);
 
            switch (cparams.pooling_type) {
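For reference, the per-token update in the hunk above locates each K/V copy destination by tensor name and recomputes only its byte offset from kv_head. A small self-contained sketch of that arithmetic with made-up values (F16 cache, 1024 elements per K/V row, the new token written at slot kv_head = 37):

#include <cstdio>
#include <cstdlib>
#include <cstring>

int main() {
    // The layer index is recovered from the view tensor's name, e.g. "k_cache_view-7".
    const char * name     = "k_cache_view-7";
    const char * k_prefix = "k_cache_view-";
    int il = -1;
    if (strncmp(name, k_prefix, strlen(k_prefix)) == 0) {
        il = atoi(name + strlen(k_prefix)); // -> 7
    }

    // Hypothetical hyperparameters for the offset arithmetic.
    const size_t type_size    = 2;    // bytes per F16 element
    const size_t n_embd_k_gqa = 1024; // elements per K row (per token)
    const size_t n_embd_v_gqa = 1024; // elements per V row (per token)
    const size_t kv_head      = 37;   // slot the new token is written to

    // K cache: one row per token, so skip kv_head whole rows.
    const size_t k_offset    = type_size * n_embd_k_gqa * kv_head; // 75776 bytes

    // V cache with flash attention: same row-per-token layout as K.
    const size_t v_offset_fa = type_size * n_embd_v_gqa * kv_head; // 75776 bytes

    // V cache without flash attention: stored transposed, so advancing one token
    // moves a single element along each row.
    const size_t v_offset    = type_size * kv_head;                // 74 bytes

    printf("layer %d: K offset %zu, V offset (flash-attn) %zu, V offset %zu\n",
           il, k_offset, v_offset_fa, v_offset);
    return 0;
}

The transposed layout of the non-flash-attention V cache (see the ggml_transpose call in llm_build_kv_store above) is why its per-token step is a single element rather than a full row.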