Skip to content

Commit de9692a

Browse files
authored
llama : fix llama_copy_state_data with fragmented KV cache (#5840)
The row size of the saved states was based on kv_self.head while it should be based on llama_kv_cache_cell_max. Existing session files should still work. * llama : fix llama_kv_cache_cell_max inability to return 1 I've also changed its return type to uint32_t, because this function is always used to set the value of uint32_t variables, and because the index already has this type. * llama : fix state size calculation Some bytes in the state were unaccounted for in llama_get_state_size. Since the logits reserve so much space, it did not cause problems.
1 parent e602934 commit de9692a

File tree

1 file changed

+30
-17
lines changed

1 file changed

+30
-17
lines changed

llama.cpp

Lines changed: 30 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -2156,10 +2156,12 @@ static bool llama_kv_cache_find_slot(
21562156
}
21572157

21582158
// find how many cells are currently in use
2159-
static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
2160-
for (uint32_t i = cache.size - 1; i > 0; --i) {
2161-
if (cache.cells[i].pos >= 0 && !cache.cells[i].is_empty()) {
2162-
return i + 1;
2159+
static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
2160+
for (uint32_t i = cache.size; i > 0; --i) {
2161+
const llama_kv_cell & cell = cache.cells[i - 1];
2162+
2163+
if (cell.pos >= 0 && !cell.is_empty()) {
2164+
return i;
21632165
}
21642166
}
21652167

@@ -8178,7 +8180,7 @@ static int llama_decode_internal(
81788180
// a heuristic, to avoid attending the full cache if it is not yet utilized
81798181
// after enough generations, the benefit from this heuristic disappears
81808182
// if we start defragmenting the cache, the benefit from this will be more important
8181-
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
8183+
kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
81828184
//kv_self.n = llama_kv_cache_cell_max(kv_self);
81838185

81848186
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
@@ -12615,9 +12617,14 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
1261512617
const size_t s_logits = ctx->logits.capacity() * sizeof(float);
1261612618
const size_t s_embedding_size = sizeof(size_t);
1261712619
const size_t s_embedding = ctx->embedding.size() * sizeof(float);
12618-
const size_t s_kv_size = sizeof(size_t);
12619-
const size_t s_kv_ntok = sizeof(int);
12620+
const size_t s_kv_buf_size = sizeof(size_t);
12621+
const size_t s_kv_head = sizeof(uint32_t);
12622+
const size_t s_kv_size = sizeof(uint32_t);
12623+
const size_t s_kv_used = sizeof(uint32_t);
1262012624
const size_t s_kv = ctx->kv_self.total_size();
12625+
// TODO: assume the max is more than 1 seq_id per KV cell
12626+
const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
12627+
const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell;
1262112628

1262212629
const size_t s_total = (
1262312630
+ s_rng_size
@@ -12626,9 +12633,12 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
1262612633
+ s_logits
1262712634
+ s_embedding_size
1262812635
+ s_embedding
12636+
+ s_kv_buf_size
12637+
+ s_kv_head
1262912638
+ s_kv_size
12630-
+ s_kv_ntok
12639+
+ s_kv_used
1263112640
+ s_kv
12641+
+ s_kv_cells
1263212642
);
1263312643

1263412644
return s_total;
@@ -12728,15 +12738,13 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
1272812738
{
1272912739
const auto & kv_self = ctx->kv_self;
1273012740
const auto & hparams = ctx->model.hparams;
12731-
const auto & cparams = ctx->cparams;
1273212741

1273312742
const uint32_t n_layer = hparams.n_layer;
1273412743
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
1273512744
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
12736-
const uint32_t n_ctx = cparams.n_ctx;
1273712745

1273812746
const size_t kv_buf_size = kv_self.total_size();
12739-
const uint32_t kv_head = kv_self.head;
12747+
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
1274012748
const uint32_t kv_size = kv_self.size;
1274112749
const uint32_t kv_used = kv_self.used;
1274212750

@@ -12756,7 +12764,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
1275612764

1275712765
// v is not contiguous, copy row by row
1275812766
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
12759-
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
12767+
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
1276012768

1276112769
tmp_buf.resize(v_row_size);
1276212770
for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
@@ -12766,7 +12774,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
1276612774
}
1276712775
}
1276812776

12769-
for (uint32_t i = 0; i < kv_size; ++i) {
12777+
for (uint32_t i = 0; i < kv_head; ++i) {
1277012778
const auto & cell = kv_self.cells[i];
1277112779

1277212780
const llama_pos pos = cell.pos;
@@ -12842,12 +12850,10 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
1284212850
{
1284312851
const auto & kv_self = ctx->kv_self;
1284412852
const auto & hparams = ctx->model.hparams;
12845-
const auto & cparams = ctx->cparams;
1284612853

1284712854
const uint32_t n_layer = hparams.n_layer;
1284812855
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
1284912856
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
12850-
const uint32_t n_ctx = cparams.n_ctx;
1285112857

1285212858
size_t kv_buf_size;
1285312859
uint32_t kv_head;
@@ -12870,7 +12876,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
1287012876

1287112877
// v is not contiguous, copy row by row
1287212878
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
12873-
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
12879+
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
1287412880

1287512881
for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
1287612882
ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
@@ -12879,13 +12885,15 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
1287912885
}
1288012886
}
1288112887

12888+
GGML_ASSERT(kv_self.size == kv_size);
12889+
1288212890
ctx->kv_self.head = kv_head;
1288312891
ctx->kv_self.size = kv_size;
1288412892
ctx->kv_self.used = kv_used;
1288512893

1288612894
ctx->kv_self.cells.resize(kv_size);
1288712895

12888-
for (uint32_t i = 0; i < kv_size; ++i) {
12896+
for (uint32_t i = 0; i < kv_head; ++i) {
1288912897
llama_pos pos;
1289012898
size_t seq_id_size;
1289112899

@@ -12901,6 +12909,11 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
1290112909
ctx->kv_self.cells[i].seq_id.insert(seq_id);
1290212910
}
1290312911
}
12912+
12913+
for (uint32_t i = kv_head; i < kv_size; ++i) {
12914+
ctx->kv_self.cells[i].pos = -1;
12915+
ctx->kv_self.cells[i].seq_id.clear();
12916+
}
1290412917
}
1290512918

1290612919
const size_t nread = inp - src;

0 commit comments

Comments (0)