
Commit 5df7d06

llama : allow exporting a view of the KV cache (#4180)
* Allow exporting a view of the KV cache
* Allow dumping the sequences per cell in common
* Track max contiguous cells value and position as well
* Fix max contiguous empty cells index calculation; make dump functions deal with lengths or sequence counts > 10 better
* Fix off-by-one error in dump_kv_cache_view
* Add doc comments for KV cache view functions; eliminate cell sequence struct and use llama_seq_id directly; minor cleanups
1 parent 671f639 commit 5df7d06

File tree

common/common.cpp
common/common.h
examples/parallel/parallel.cpp
llama.cpp
llama.h

5 files changed: +227 additions, -0 deletions

common/common.cpp

Lines changed: 75 additions & 0 deletions
@@ -12,6 +12,7 @@
 #include <regex>
 #include <sstream>
 #include <string>
+#include <unordered_map>
 #include <unordered_set>
 #include <vector>
 #include <cinttypes>
@@ -1386,3 +1387,77 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
 }
+
+//
+// KV cache utils
+//
+
+void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
+    static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
+
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
+        view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+    llama_kv_cache_view_cell * c_curr = view.cells;
+    llama_seq_id * cs_curr = view.cells_sequences;
+
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+        if (i % row_size == 0) {
+            printf("\n%5d: ", i);
+        }
+        int seq_count = 0;
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j] >= 0) { seq_count++; }
+        }
+        putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
+    }
+
+    printf("\n=== Done dumping\n");
+}
+
+void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
+    static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
+
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
+        view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+    std::unordered_map<llama_seq_id, size_t> seqs;
+    llama_kv_cache_view_cell * c_curr = view.cells;
+    llama_seq_id * cs_curr = view.cells_sequences;
+
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j] < 0) { continue; }
+            if (seqs.find(cs_curr[j]) == seqs.end()) {
+                if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+                seqs[cs_curr[j]] = seqs.size();
+            }
+        }
+        if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+    }
+
+    printf("=== Sequence legend: ");
+    for (const auto & it : seqs) {
+        printf("%zu=%d, ", it.second, it.first);
+    }
+    printf("'+'=other sequence ids");
+
+    c_curr = view.cells;
+    cs_curr = view.cells_sequences;
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+        if (i % row_size == 0) {
+            printf("\n%5d: ", i);
+        }
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j] >= 0) {
+                const auto & it = seqs.find(cs_curr[j]);
+                putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
+            } else {
+                putchar('.');
+            }
+        }
+        putchar(' ');
+    }
+
+    printf("\n=== Done dumping\n");
+}
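For orientation, here is a hand-made sketch (illustrative numbers, not captured from a real run) of what dump_kv_cache_view prints. Each character is one cell: '.' means empty, digits give the number of sequences in the cell, and '+' marks counts past the end of the legend string:

    === Dumping KV cache. total cells 16, max sequences per cell 4, populated cells 8, total tokens in cache 13, largest empty slot=6 @ 10
        0: 1111222..3......
    === Done dumping

In this sketch, cells 0-3 hold one sequence each, cells 4-6 hold two, cell 9 holds three (4 + 6 + 3 = 13 tokens across 8 populated cells), and the longest empty run is the 6 cells starting at index 10.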

common/common.h

Lines changed: 10 additions & 0 deletions
@@ -218,3 +218,13 @@ std::string get_sortable_timestamp();
 void dump_non_result_info_yaml(
     FILE * stream, const gpt_params & params, const llama_context * lctx,
     const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
+
+//
+// KV cache utils
+//
+
+// Dump the KV cache view with the number of sequences per cell.
+void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
+
+// Dump the KV cache view showing individual sequences in each cell (long output).
+void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
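Taken together, a minimal sketch of the intended call sequence (ctx and the n_max_seq value of 4 are placeholder assumptions; the functions are the ones this commit adds):

    // Sketch: create a view sized for up to 4 sequence ids per cell,
    // refresh it from the context, and print both dump styles.
    struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 4);
    llama_kv_cache_view_update(ctx, &kvc_view);
    dump_kv_cache_view(kvc_view);          // one char per cell, default row_size = 80
    dump_kv_cache_view_seqs(kvc_view, 40); // per-sequence detail, longer output
    llama_kv_cache_view_free(&kvc_view);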

examples/parallel/parallel.cpp

Lines changed: 5 additions & 0 deletions
@@ -172,6 +172,8 @@ int main(int argc, char ** argv) {
     int32_t n_total_gen = 0;
     int32_t n_cache_miss = 0;

+    struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients);
+
     const auto t_main_start = ggml_time_us();

     LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);

@@ -201,6 +203,9 @@
     LOG_TEE("Processing requests ...\n\n");

     while (true) {
+        llama_kv_cache_view_update(ctx, &kvc_view);
+        dump_kv_cache_view_seqs(kvc_view, 40);
+
         llama_batch_clear(batch);

         // decode any currently ongoing sequences
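The example refreshes and dumps the sequence view on every scheduler iteration, which is fine for a demo but verbose. A hypothetical throttle, not part of this commit, could gate the dump on an iteration counter:

    // Hypothetical variant (not in this commit): dump only every 100 iterations.
    static int64_t iter = 0;
    if (iter++ % 100 == 0) {
        llama_kv_cache_view_update(ctx, &kvc_view);
        dump_kv_cache_view_seqs(kvc_view, 40);
    }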

llama.cpp

Lines changed: 89 additions & 0 deletions
@@ -8805,6 +8805,95 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
     }
 }

+struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
+    struct llama_kv_cache_view result = {
+        /*.n_cells            = */ 0,
+        /*.n_max_seq          = */ n_max_seq,
+        /*.token_count        = */ 0,
+        /*.used_cells         = */ llama_get_kv_cache_used_cells(ctx),
+        /*.max_contiguous     = */ 0,
+        /*.max_contiguous_idx = */ -1,
+        /*.cells              = */ nullptr,
+        /*.cells_sequences    = */ nullptr,
+    };
+    return result;
+}
+
+void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
+    if (view->cells != nullptr) {
+        free(view->cells);
+        view->cells = nullptr;
+    }
+    if (view->cells_sequences != nullptr) {
+        free(view->cells_sequences);
+        view->cells_sequences = nullptr;
+    }
+}
+
+void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
+    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
+        view->n_cells = int32_t(ctx->kv_self.size);
+        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
+        view->cells = (struct llama_kv_cache_view_cell *)p;
+        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
+        view->cells_sequences = (llama_seq_id *)p;
+    }
+
+    const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
+    llama_kv_cache_view_cell * c_curr = view->cells;
+    llama_seq_id * cs_curr = view->cells_sequences;
+    int32_t used_cells = 0;
+    int32_t token_count = 0;
+    int32_t curr_contig_idx = -1;
+    uint32_t max_contig = 0;
+    int32_t max_contig_idx = -1;
+
+    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
+        const size_t curr_size = kv_cells[i].seq_id.size();
+        token_count += curr_size;
+        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
+
+        if (curr_size > 0) {
+            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
+                max_contig = i - curr_contig_idx;
+                max_contig_idx = curr_contig_idx;
+            }
+            curr_contig_idx = -1;
+        } else if (curr_contig_idx < 0) {
+            curr_contig_idx = i;
+        }
+
+        int seq_idx = 0;
+        for (const llama_seq_id it : kv_cells[i].seq_id) {
+            if (seq_idx >= view->n_max_seq) {
+                break;
+            }
+            cs_curr[seq_idx] = it;
+            seq_idx++;
+        }
+        if (seq_idx != 0) {
+            used_cells++;
+        }
+        for (; seq_idx < view->n_max_seq; seq_idx++) {
+            cs_curr[seq_idx] = -1;
+        }
+    }
+    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
+        max_contig_idx = curr_contig_idx;
+        max_contig = kv_cells.size() - curr_contig_idx;
+    }
+    view->max_contiguous = max_contig;
+    view->max_contiguous_idx = max_contig_idx;
+    view->token_count = token_count;
+    view->used_cells = used_cells;
+    if (uint32_t(used_cells) != ctx->kv_self.used) {
+        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
+            __func__, ctx->kv_self.used, used_cells);
+    }
+}
+
 int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
     int result = 0;

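The contiguous-empty-run bookkeeping is the subtle part of llama_kv_cache_view_update: the scan opens a run at the first empty cell (curr_contig_idx), closes it when a populated cell is hit, keeps the longest run seen, and a final check after the loop handles a run that extends to the end of the cache. A standalone toy version of just that scan, under the assumption that occupancy[i] != 0 marks a populated cell:

    #include <stdio.h>

    // Toy version of the max-contiguous-empty scan in llama_kv_cache_view_update.
    // Mirrors the curr_contig_idx / max_contig tracking on a tiny occupancy array.
    int main(void) {
        const int occupancy[] = {1, 0, 0, 1, 0, 0, 0, 1, 0};
        const int n = sizeof(occupancy) / sizeof(occupancy[0]);
        int curr_idx = -1, max_len = 0, max_idx = -1;
        for (int i = 0; i < n; i++) {
            if (occupancy[i]) {
                if (curr_idx >= 0 && i - curr_idx > max_len) {
                    max_len = i - curr_idx;   // close the current empty run
                    max_idx = curr_idx;
                }
                curr_idx = -1;
            } else if (curr_idx < 0) {
                curr_idx = i;                 // open a new empty run
            }
        }
        if (curr_idx >= 0 && n - curr_idx > max_len) {
            max_len = n - curr_idx;           // run reaching the end of the cache
            max_idx = curr_idx;
        }
        printf("largest empty slot=%d @ %d\n", max_len, max_idx); // prints 3 @ 4
        return 0;
    }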
llama.h

Lines changed: 48 additions & 0 deletions
@@ -361,6 +361,54 @@ extern "C" {
     // KV cache
     //

+    // Information associated with an individual cell in the KV cache view.
+    struct llama_kv_cache_view_cell {
+        // The position for this cell. Takes KV cache shifts into account.
+        // May be negative if the cell is not populated.
+        llama_pos pos;
+    };
+
+    // An updateable view of the KV cache.
+    struct llama_kv_cache_view {
+        // Number of KV cache cells. This will be the same as the context size.
+        int32_t n_cells;
+
+        // Maximum number of sequences that can exist in a cell. It's not an error
+        // if there are more sequences in a cell than this value, however they will
+        // not be visible in the view cells_sequences.
+        int32_t n_max_seq;
+
+        // Number of tokens in the cache. For example, if there are two populated
+        // cells, the first with 1 sequence id in it and the second with 2 sequence
+        // ids then you'll have 3 tokens.
+        int32_t token_count;
+
+        // Number of populated cache cells.
+        int32_t used_cells;
+
+        // Maximum contiguous empty slots in the cache.
+        int32_t max_contiguous;
+
+        // Index to the start of the max_contiguous slot range. Can be negative
+        // when cache is full.
+        int32_t max_contiguous_idx;
+
+        // Information for an individual cell.
+        struct llama_kv_cache_view_cell * cells;
+
+        // The sequences for each cell. There will be n_max_seq items per cell.
+        llama_seq_id * cells_sequences;
+    };
+
+    // Create an empty KV cache view.
+    LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
+
+    // Free a KV cache view.
+    LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
+
+    // Update the KV cache view structure with the current state of the KV cache.
+    LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
+
     // Returns the number of tokens in the KV cache (slow, use only for debug)
     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
     LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
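A minimal sketch of consuming this API directly, without the common.cpp dump helpers. It assumes a valid ctx, and relies on the fact (visible in llama_kv_cache_view_update above) that each cell's row in cells_sequences is front-packed with real sequence ids and padded with -1:

    // Sketch: walk the view manually. cells_sequences is a row-major
    // n_cells x n_max_seq array; cells[i].pos is the shift-adjusted position.
    struct llama_kv_cache_view view = llama_kv_cache_view_init(ctx, 4);
    llama_kv_cache_view_update(ctx, &view);
    for (int32_t i = 0; i < view.n_cells; i++) {
        const llama_seq_id * seqs = view.cells_sequences + (size_t)i * view.n_max_seq;
        for (int32_t j = 0; j < view.n_max_seq && seqs[j] >= 0; j++) {
            // cell i holds a token of sequence seqs[j] at position view.cells[i].pos
        }
    }
    llama_kv_cache_view_free(&view);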
