Commit acaa5c3

Authored by agray3, committed by Nexesenex
ggml: avoid rebuild of GGML graph for each token (ggml-org#7456)
Introduces caching of the GGML graph to avoid an unnecessary full rebuild for each token. KV cache parameters, which change with each token, are updated directly in the cached GGML graph. Caching can be disabled with the GGML_DISABLE_GRAPH_CACHING environment variable. Also fixes a seg fault.
Parent commit: 8f980e1

3 files changed: +158, -8 lines
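At a high level, the commit builds the GGML compute graph once, keeps it alive in the llama_context, and on later tokens only patches the per-token parameters (the KV-cache write addresses) inside the cached graph. Below is a minimal, self-contained C++ sketch of that build-once / patch-per-token pattern; the types and helpers in it (FakeGraph, CachedGraph, build_graph) are illustrative stand-ins, not the real ggml or llama.cpp API.

// Minimal stand-in sketch of the build-once / patch-per-token idea from this
// commit. Names and types are illustrative only, not the real ggml API.
#include <cstdio>
#include <cstdlib>
#include <vector>

struct FakeGraph {                         // stands in for ggml_cgraph
    std::vector<size_t> kv_write_offsets;  // per-layer KV write positions
};

struct CachedGraph {                       // mirrors the role of struct ggml_cached_graph
    bool      is_active = false;
    FakeGraph graph;
};

// Expensive step we want to avoid repeating (stands in for llama_build_graph).
static FakeGraph build_graph(int n_layer, size_t kv_head) {
    FakeGraph g;
    for (int il = 0; il < n_layer; ++il) {
        g.kv_write_offsets.push_back(kv_head * 1024 /* row-size stand-in */);
    }
    return g;
}

int main() {
    const int n_layer = 4;
    CachedGraph cache;
    // caching can be switched off, analogous to GGML_DISABLE_GRAPH_CACHING
    const bool caching_disabled = std::getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr;

    for (size_t token = 0; token < 8; ++token) {
        const size_t kv_head = token;       // KV write position advances per token
        if (!cache.is_active) {
            cache.graph     = build_graph(n_layer, kv_head);  // full (re)build
            cache.is_active = !caching_disabled;
        } else {
            // cheap in-place update of the parameters that change per token
            for (int il = 0; il < n_layer; ++il) {
                cache.graph.kv_write_offsets[il] = kv_head * 1024;
            }
        }
        std::printf("token %zu: offset[0]=%zu\n", token, cache.graph.kv_write_offsets[0]);
    }
    return 0;
}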

ggml-backend.c

Lines changed: 32 additions & 1 deletion
@@ -1036,6 +1036,13 @@ struct ggml_backend_sched_split {
     struct ggml_cgraph graph;
 };
 
+// Object to facilitate GGML graph caching
+struct ggml_cached_graph {
+    bool is_active;
+    ggml_backend_t input_backend;
+    struct ggml_tensor * input_cpy[GGML_SCHED_MAX_SPLIT_INPUTS];
+};
+
 struct ggml_backend_sched {
     bool is_reset; // true if the scheduler has been reset since the last graph split
     bool is_alloc;
@@ -1087,6 +1094,8 @@ struct ggml_backend_sched {
 __attribute__((aligned(GGML_MEM_ALIGN)))
 #endif
     char context_buffer[GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
+
+    struct ggml_cached_graph cached_graph;
 };
 
 #define hash_id(tensor) ggml_hash_find_or_insert(sched->hash_set, tensor)
@@ -1753,6 +1762,14 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
             struct ggml_tensor * input = split->inputs[j];
             struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split_backend_id][sched->cur_copy];
 
+            if (!sched->cached_graph.is_active) {
+                sched->cached_graph.input_backend = input_backend;
+                sched->cached_graph.input_cpy[j] = input_cpy;
+            }
+            else {
+                input_backend = sched->cached_graph.input_backend;
+                input_cpy = sched->cached_graph.input_cpy[j];
+            }
             if (input->flags & GGML_TENSOR_FLAG_INPUT) {
                 // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
                 if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
@@ -1872,6 +1889,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
 
     ggml_backend_sched_reset(sched);
 
+    sched->cached_graph.is_active = false;
+
     return sched;
 }
 
@@ -1947,6 +1966,9 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st
 }
 
 enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+
+    if(!sched->cached_graph.is_active)
+    {
     if (!sched->is_reset && !sched->is_alloc) {
         ggml_backend_sched_reset(sched);
     }
@@ -1956,7 +1978,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sch
             return GGML_STATUS_ALLOC_FAILED;
         }
     }
-
+    }
     return ggml_backend_sched_compute_splits(sched);
 }
 
@@ -2223,3 +2245,12 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
 
     return true;
 }
+
+bool ggml_use_cached_graph(ggml_backend_sched_t sched) {
+    return sched->cached_graph.is_active;
+}
+
+void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value) {
+    sched->cached_graph.is_active = set_value;
+}
+

ggml-backend.h

Lines changed: 5 additions & 0 deletions
@@ -230,6 +230,11 @@ extern "C" {
     GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
     GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
 
+    // Utility to query whether cached GGML graph is in use
+    GGML_API bool ggml_use_cached_graph(ggml_backend_sched_t sched);
+
+    // Set whether or not to use GGML graph caching
+    GGML_API void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value);
 
 #ifdef __cplusplus
 }
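These two declarations are the whole public surface of the caching feature on the scheduler side. A hedged caller-side sketch of how they combine with the existing scheduler calls is shown below; evaluate_token is a hypothetical helper, and the snippet assumes a ggml build that already contains this patch plus a ggml_backend_sched_t and built graph created elsewhere.

#include "ggml-backend.h"

// Hypothetical helper: run one token's graph, reusing the scheduled graph
// from the previous token when the cache is active.
static enum ggml_status evaluate_token(ggml_backend_sched_t sched, struct ggml_cgraph * gf) {
    if (!ggml_use_cached_graph(sched)) {
        // first token (or caching switched off): split and allocate the
        // freshly built graph, then opt in to reusing it for later tokens
        ggml_backend_sched_alloc_graph(sched, gf);
        ggml_set_cached_graph(sched, true);
    }
    // with an active cache, graph_compute_async skips the reset/alloc path
    // and compute_splits reuses the cached split input copies
    return ggml_backend_sched_graph_compute_async(sched, gf);
}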

llama.cpp

Lines changed: 121 additions & 7 deletions
@@ -2745,6 +2745,17 @@ struct llama_model {
     }
 };
 
+// Object used to allow caching of GGML graph between tokens where possible.
+struct ggml_cached_graph {
+    bool is_active = false;
+    ggml_cgraph * gf;
+    size_t n;
+    ggml_backend_t backend_res;
+    ggml_backend_t backend_embd;
+    struct ggml_tensor * res;
+    struct ggml_tensor * embd;
+};
+
 struct llama_context {
     llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
     ~llama_context() {
@@ -2846,6 +2857,8 @@ struct llama_context {
 
     // control vectors
     struct llama_control_vector cvec;
+
+    struct ggml_cached_graph cached_graph;
 };
 
 static size_t llama_get_device_count(const llama_model & model) {
@@ -14668,12 +14681,42 @@ static int llama_decode_internal(
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
-
+        ggml_cgraph * gf;
         // the output is always the last tensor in the graph
-        struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
-        struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+        struct ggml_tensor * res;
+        struct ggml_tensor * embd;
+
+        bool n_has_changed_since_last_token = false;
+        if(lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
+        lctx.cached_graph.n = kv_self.n;
+
+        // Re-build graph only if graph caching is not possible
+        if(!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
+
+            gf = llama_build_graph(lctx, u_batch, false);
+
+            // Set whether GGML graph caching is in use within GGML module, based on
+            // whether caching was activated here during the previous token
+            ggml_set_cached_graph(lctx.sched,lctx.cached_graph.is_active);
+
+            // Disable future graph caching in presence of env var,
+            // if there are multiple devices, or if batch size is greater than 1
+            // TODO: enable graph caching for these cases
+            bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+                || (llama_get_device_count(model) > 1);
+            for (int i = 0 ; i < gf->n_nodes; i++) {
+                if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+                    disable_cached_ggml_graph = true;
+                    break;
+                }
+            }
+
+            // Set whether graph caching should be used for future tokens
+            lctx.cached_graph.is_active=!disable_cached_ggml_graph;
 
+            // the output is always the last tensor in the graph
+            res = gf->nodes[gf->n_nodes - 1];
+            embd = gf->nodes[gf->n_nodes - 2];
         if (lctx.n_outputs == 0) {
             // no output
             res = nullptr;
@@ -14689,10 +14732,71 @@ static int llama_decode_internal(
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+        lctx.cached_graph.res = res;
+        lctx.cached_graph.embd = embd;
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
+        }
+        else {
+            gf = lctx.cached_graph.gf;
+            res = lctx.cached_graph.res;
+            embd = lctx.cached_graph.embd;
+        }
+        lctx.cached_graph.gf = gf;
+
+        if(ggml_use_cached_graph(lctx.sched)) {
+
+            // If using flash attention, find mask node so it can be skipped when updating
+            // KV cache parameters in cached graph nodes below
+            void * flash_attn_mask_node = nullptr;
+            if(cparams.flash_attn) {
+                for (int i = 0; i < gf->n_nodes; i++) {
+                    ggml_tensor * node = gf->nodes[i];
+                    if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                        flash_attn_mask_node = node->src[3];
+                        break;
+                    }
+                }
+            }
+
+            // Temporarily store KV cache parameters that will need to be updated in the cached graph.
+            const struct llama_hparams & hparams = model.hparams;
+            const int64_t n_layer = hparams.n_layer;
+            const int64_t kv_head = kv_self.head;
+            std::vector<void *> kv_cache_ptrs;
+            for (int il = 0; il < n_layer; ++il) {
+                const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+                const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+                ggml_tensor * tmp_tensor = kv_self.k_l[il];
+                size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
+                kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+                tmp_tensor = kv_self.v_l[il];
+                if (cparams.flash_attn) {
+                    tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                } else {
+                    tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
+                }
+                kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+            }
+
+            // Update KV cache parameters in cached graph.
+            int copy_op_count = 0;
+            if(gf != nullptr && gf->nodes != nullptr){
+                for (int i = 0; i < gf->n_nodes; i++) {
+                    ggml_tensor * node = gf->nodes[i];
+                    if (node->op == GGML_OP_CPY) {
+                        if (node != flash_attn_mask_node) {
+                            node->src[1]->data = kv_cache_ptrs[copy_op_count];
+                            copy_op_count++;
+                        }
+                    }
+                }
+            }
+
+        }
+
         llama_set_inputs(lctx, u_batch);
 
         llama_graph_compute(lctx, gf, n_threads);
@@ -14715,11 +14819,15 @@ static int llama_decode_internal(
         // extract logits
         if (res) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
-            GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(lctx.logits != nullptr);
-
             float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
             const int32_t n_outputs_new = lctx.n_outputs;
+            if(!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_res = backend_res;
+            else
+                backend_res = lctx.cached_graph.backend_res;
+
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(lctx.logits != nullptr);
 
             if (n_outputs_new) {
                 GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
@@ -14731,6 +14839,12 @@ static int llama_decode_internal(
         // extract embeddings
         if (embd) {
            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+
+
+            if(!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_embd = backend_embd;
+            else
+                backend_embd = lctx.cached_graph.backend_embd;
            GGML_ASSERT(backend_embd != nullptr);
 
            switch (cparams.pooling_type) {
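The per-layer pointers collected into kv_cache_ptrs above are the only values that change from token to token. The short worked example below reproduces that offset arithmetic with stand-in sizes (the constants are illustrative, not taken from a real model): the K cache advances by one full row of n_embd_k_gqa elements per token, while the V cache advances by a full row only when flash attention keeps it in the same layout as K, and by a single element per token in the default transposed layout, which is why the patch uses ggml_element_size there.

#include <cstdio>

int main() {
    // illustrative stand-in values, not read from a real model
    const long long n_embd_k_gqa = 1024;  // K elements written per token per layer
    const long long n_embd_v_gqa = 1024;  // V elements written per token per layer
    const long long type_size    = 2;     // e.g. 2 bytes for an f16 KV cache
    const long long kv_head      = 37;    // current write position in the KV cache

    // K: one full row per token, mirroring ggml_row_size(type, n_embd_k_gqa)*kv_head
    const long long k_offset = (n_embd_k_gqa * type_size) * kv_head;

    // V with flash attention: same layout as K
    const long long v_offset_fa = (n_embd_v_gqa * type_size) * kv_head;

    // V without flash attention: transposed layout, each token advances by a
    // single element within every row, mirroring ggml_element_size(v)*kv_head
    const long long v_offset_no_fa = type_size * kv_head;

    std::printf("k=%lld  v(flash)=%lld  v(no flash)=%lld\n",
                k_offset, v_offset_fa, v_offset_no_fa);
    return 0;
}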
