
Commit ef10add

agray3 authored and Nexesenex committed
ggml: avoid rebuild of GGML graph for each token (ggml-org#7456)
Introduces caching of the GGML graph to avoid an unnecessary full rebuild between each token. KV cache parameters, which change with each token, are updated directly in the cached GGML graph. Caching can be disabled with the GGML_DISABLE_GRAPH_CACHING environment variable. Also: fix seg fault; restrict to nsplit=2; improve identification of K and V nodes for param updates.
1 parent 78fac97 commit ef10add
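
The cache is opt-out: setting the GGML_DISABLE_GRAPH_CACHING environment variable (to any value) disables it. Below is a minimal, self-contained sketch of that check, matching the test used in llama_decode_internal further down; only the variable name comes from this commit, the surrounding program is illustrative:

    // check_graph_caching.cpp - sketch of the opt-out check used by this commit
    #include <cstdlib>
    #include <cstdio>

    int main() {
        // the diff only tests for presence of the variable, not its value
        const bool caching_disabled = std::getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr;
        std::printf("GGML graph caching %s\n", caching_disabled ? "disabled" : "enabled");
        return 0;
    }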

File tree

ggml/include/ggml-backend.h
ggml/include/ggml.h
ggml/src/ggml-backend.c
ggml/src/ggml.c
src/llama.cpp

5 files changed: +181 -13 lines changed


ggml/include/ggml-backend.h

Lines changed: 5 additions & 0 deletions
@@ -232,6 +232,11 @@ extern "C" {
     GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
     GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
 
+    // Utility to query whether cached GGML graph is in use
+    GGML_API bool ggml_use_cached_graph(ggml_backend_sched_t sched);
+
+    // Set whether or not to use GGML graph caching
+    GGML_API void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value);
 
 #ifdef __cplusplus
 }
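
The two declarations above are the whole public surface of the cache. A hedged sketch of how calling code might drive them; the scheduler handle and the invalidation condition are assumptions for illustration, not part of this commit:

    #include "ggml-backend.h"

    // Hypothetical helper: fall back to a full rebuild when the cached graph can
    // no longer be reused (e.g. the KV layout changed between tokens).
    static void invalidate_graph_cache(ggml_backend_sched_t sched, bool layout_changed) {
        if (layout_changed && ggml_use_cached_graph(sched)) {
            ggml_set_cached_graph(sched, false); // the next decode will rebuild and re-allocate
        }
    }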

ggml/include/ggml.h

Lines changed: 10 additions & 1 deletion
@@ -570,6 +570,13 @@ extern "C" {
         GGML_TENSOR_FLAG_PARAM = 4,
     };
 
+    // Flag (used on GGML_OP_CPY nodes) on whether node is associated with K or V cache
+    enum ggml_kv_cache_flag {
+        GGML_KV_CACHE_FLAG_NONE = 0,
+        GGML_KV_CACHE_FLAG_K = 1,
+        GGML_KV_CACHE_FLAG_V = 2
+    };
+
     // ggml object
     struct ggml_object {
         size_t offs;
@@ -604,6 +611,8 @@ extern "C" {
         // op params - allocated as int32_t for alignment
         int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 
+        enum ggml_kv_cache_flag kv_cache_flag;
+
         int32_t flags;
 
         struct ggml_tensor * grad;
@@ -619,7 +628,7 @@ extern "C" {
 
         void * extra; // extra things e.g. for ggml-cuda.cu
 
-        // char padding[4];
+        char padding[1];
     };
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
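
The intended use of the new flag is to tag the GGML_OP_CPY node that writes K (or V) into the cache, so that a cached graph can later find that node and retarget its destination pointer. The sketch below mirrors what llm_build_kv_store does in the src/llama.cpp diff further down; the helper name and signature are illustrative only:

    #include "ggml.h"

    // Hypothetical helper: copy k_cur into its KV-cache view and tag the copy node
    // so a cached graph can later repoint it at the current cache head.
    static void build_k_store(struct ggml_context * ctx, struct ggml_cgraph * graph,
                              struct ggml_tensor * k_cur, struct ggml_tensor * k_cache_view) {
        struct ggml_tensor * cpy = ggml_cpy(ctx, k_cur, k_cache_view);
        cpy->kv_cache_flag = GGML_KV_CACHE_FLAG_K;
        ggml_build_forward_expand(graph, cpy);
    }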

ggml/src/ggml-backend.c

Lines changed: 42 additions & 1 deletion
@@ -1045,6 +1045,13 @@ struct ggml_backend_sched_split {
     struct ggml_cgraph graph;
 };
 
+// Object to facilitate GML graph caching
+struct ggml_cached_graph {
+    bool is_active;
+    ggml_backend_t input_backend;
+    struct ggml_tensor * input_cpy[GGML_SCHED_MAX_SPLIT_INPUTS];
+};
+
 struct ggml_backend_sched {
     bool is_reset; // true if the scheduler has been reset since the last graph split
     bool is_alloc;
@@ -1090,6 +1097,18 @@ struct ggml_backend_sched {
     size_t context_buffer_size;
 
     bool debug;
+
+    // align context_buffer to GGML_MEM_ALIGN
+
+    // #ifdef _MSC_VER
+    // __declspec(align(GGML_MEM_ALIGN))
+    // #else
+    // __attribute__((aligned(GGML_MEM_ALIGN)))
+    // #endif
+
+    // char context_buffer[GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
+
+    // struct ggml_cached_graph cached_graph;
 };
 
 #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
@@ -1767,6 +1786,14 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
             struct ggml_tensor * input = split->inputs[j];
             struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
 
+            if (!sched->cached_graph.is_active) {
+                sched->cached_graph.input_backend = input_backend;
+                sched->cached_graph.input_cpy[j] = input_cpy;
+            }
+            else {
+                input_backend = sched->cached_graph.input_backend;
+                input_cpy = sched->cached_graph.input_cpy[j];
+            }
             if (input->flags & GGML_TENSOR_FLAG_INPUT) {
                 // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
                 if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
@@ -1888,6 +1915,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
 
     ggml_backend_sched_reset(sched);
 
+    sched->cached_graph.is_active = false;
+
     return sched;
 }
 
@@ -1964,6 +1993,9 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st
 }
 
 enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+
+    if(!sched->cached_graph.is_active)
+    {
     if (!sched->is_reset && !sched->is_alloc) {
         ggml_backend_sched_reset(sched);
     }
@@ -1973,7 +2005,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sch
             return GGML_STATUS_ALLOC_FAILED;
         }
     }
-
+    }
     return ggml_backend_sched_compute_splits(sched);
 }
 
@@ -2238,3 +2270,12 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
 
     return true;
 }
+
+bool ggml_use_cached_graph(ggml_backend_sched_t sched) {
+    return sched->cached_graph.is_active;
+}
+
+void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value) {
+    sched->cached_graph.is_active = set_value;
+}
+
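
From the caller's side, the scheduler change above means that once the cached graph is active, ggml_backend_sched_graph_compute_async skips the reset/alloc path and compute_splits reuses the input copies recorded on the first pass. A hedged sketch under those assumptions (sched and gf are created elsewhere, error handling omitted):

    #include "ggml-backend.h"

    static void compute_token(ggml_backend_sched_t sched, struct ggml_cgraph * gf) {
        if (!ggml_use_cached_graph(sched)) {
            // first token, or cache invalidated: split and allocate the graph as before
            ggml_backend_sched_alloc_graph(sched, gf);
        }
        // with an active cached graph this skips reset/alloc and reuses recorded inputs
        ggml_backend_sched_graph_compute_async(sched, gf);
    }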

ggml/src/ggml.c

Lines changed: 2 additions & 1 deletion
@@ -3770,6 +3770,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.nb =*/ { 0, 0, 0, 0 },
         /*.op =*/ GGML_OP_NONE,
         /*.op_params =*/ { 0 },
+        /*.kv_cache_flag=*/ GGML_KV_CACHE_FLAG_NONE,
         /*.flags =*/ 0,
         /*.grad =*/ NULL,
         /*.src =*/ { NULL },
@@ -3778,7 +3779,7 @@
         /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
         /*.name =*/ { 0 },
         /*.extra =*/ NULL,
-        ///*.padding =*/ { 0 },
+        /*.padding =*/ { 0 },
     };
 
 #ifdef __clang__

src/llama.cpp

Lines changed: 122 additions & 10 deletions
@@ -2648,6 +2648,17 @@ struct llama_model {
     }
 };
 
+// Object used to allow caching of GGML graph between tokens where possible.
+struct ggml_cached_graph {
+    bool is_active = false;
+    ggml_cgraph * gf;
+    size_t n;
+    ggml_backend_t backend_res;
+    ggml_backend_t backend_embd;
+    struct ggml_tensor * res;
+    struct ggml_tensor * embd;
+};
+
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
@@ -2748,6 +2759,10 @@ struct llama_context {
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+
+    // cached Cuda Graphs
+    struct ggml_cached_graph cached_graph;
+
 };
 
 struct llama_lora_weight {
@@ -7902,7 +7917,9 @@ static void llm_build_kv_store(
     cb(k_cache_view, "k_cache_view", il);
 
     // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
+    ggml_tensor * tmp = ggml_cpy(ctx, k_cur, k_cache_view);
+    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_K;
+    ggml_build_forward_expand(graph, tmp);
 
     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
 
@@ -7920,8 +7937,9 @@ static void llm_build_kv_store(
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
-
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
+    tmp=ggml_cpy(ctx, v_cur, v_cache_view);
+    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_V;
+    ggml_build_forward_expand(graph, tmp);
 }
 
 // do mat_mul, while optionally apply lora
@@ -14729,12 +14747,44 @@ static int llama_decode_internal(
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
-
+        ggml_cgraph * gf;
         // the output is always the last tensor in the graph
-        struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
-        struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+        struct ggml_tensor * res;
+        struct ggml_tensor * embd;
+
+        bool n_has_changed_since_last_token = false;
+        if(lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
+        lctx.cached_graph.n = kv_self.n;
+
+        // Re-build graph only if graph caching is not possible
+        if(!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
+
+            gf = llama_build_graph(lctx, u_batch, false);
+
+            // Set whether GGML graph caching is in use within GGML module, based on
+            // whether caching was activated here during the previous token
+            ggml_set_cached_graph(lctx.sched,lctx.cached_graph.is_active);
+
+            // Disable future graph caching in presence of env var,
+            // if there are multiple devices, if batch size is greater than 1,
+            // or if nsplits is not 2.
+            // TO DO enable graph caching for these cases
+            bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+                || (llama_get_device_count(model) > 1)
+                || (ggml_backend_sched_get_n_splits(lctx.sched) != 2);
+            for (int i = 0 ; i < gf->n_nodes; i++) {
+                if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+                    disable_cached_ggml_graph = true;
+                    break;
+                }
+            }
+
+            // Set whether graph caching should be used for future tokens
+            lctx.cached_graph.is_active=!disable_cached_ggml_graph;
 
+            // the output is always the last tensor in the graph
+            res = gf->nodes[gf->n_nodes - 1];
+            embd = gf->nodes[gf->n_nodes - 2];
         if (lctx.n_outputs == 0) {
             // no output
             res = nullptr;
@@ -14750,10 +14800,62 @@ static int llama_decode_internal(
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+        lctx.cached_graph.res = res;
+        lctx.cached_graph.embd = embd;
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
+        }
+        else {
+            gf = lctx.cached_graph.gf;
+            res = lctx.cached_graph.res;
+            embd = lctx.cached_graph.embd;
+        }
+        lctx.cached_graph.gf = gf;
+
+        if(ggml_use_cached_graph(lctx.sched)) {
+
+            // Temporarily store KV cache parameters that will need updated in cached graph.
+            const struct llama_hparams & hparams = model.hparams;
+            const int64_t n_layer = hparams.n_layer;
+            const int64_t kv_head = kv_self.head;
+            std::vector<void *> k_cache_ptrs;
+            std::vector<void *> v_cache_ptrs;
+            for (int il = 0; il < n_layer; ++il) {
+                const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+                const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+                ggml_tensor * tmp_tensor = kv_self.k_l[il];
+                size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
+                k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+                tmp_tensor = kv_self.v_l[il];
+                if (cparams.flash_attn) {
+                    tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                } else {
+                    tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
+                }
+                v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+            }
+
+            // Update KV cache parameters in cached graph.
+            int k_count = 0;
+            int v_count = 0;
+            if(gf != nullptr && gf->nodes != nullptr){
+                for (int i = 0; i < gf->n_nodes; i++) {
+                    ggml_tensor * node = gf->nodes[i];
+                    if (node->op == GGML_OP_CPY) {
+                        if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_K) {
+                            node->src[1]->data = k_cache_ptrs[k_count++];
+                        }
+                        if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_V) {
+                            node->src[1]->data = v_cache_ptrs[v_count++];
+                        }
+                    }
+                }
+            }
+
+        }
+
         llama_set_inputs(lctx, u_batch);
 
         llama_graph_compute(lctx, gf, n_threads);
@@ -14776,11 +14878,15 @@ static int llama_decode_internal(
         // extract logits
         if (res) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
-            GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(lctx.logits != nullptr);
-
             float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
             const int32_t n_outputs_new = lctx.n_outputs;
+            if(!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_res = backend_res;
+            else
+                backend_res = lctx.cached_graph.backend_res;
+
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(lctx.logits != nullptr);
 
             if (n_outputs_new) {
                 GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
@@ -14792,6 +14898,12 @@ static int llama_decode_internal(
         // extract embeddings
         if (embd) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+
+
+            if(!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_embd = backend_embd;
+            else
+                backend_embd = lctx.cached_graph.backend_embd;
             GGML_ASSERT(backend_embd != nullptr);
 
             switch (cparams.pooling_type) {
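
For the KV-cache pointer refresh in llama_decode_internal above, the per-layer K write offset is ggml_row_size(type, n_embd_k_gqa) * kv_head. A worked sketch of that arithmetic with made-up numbers (not taken from this commit or any particular model):

    #include <cstdio>
    #include <cstddef>

    int main() {
        const size_t n_embd_k_gqa   = 1024; // hypothetical K elements per token per layer
        const size_t bytes_per_elem = 2;    // F16 cache; ggml_row_size also covers quantized types
        const size_t kv_head        = 37;   // first free slot in the KV cache
        const size_t offset = n_embd_k_gqa * bytes_per_elem * kv_head; // 75776 bytes into k_l[il]->data
        std::printf("K-cache write offset: %zu bytes\n", offset);
        return 0;
    }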
