@@ -2156,10 +2156,12 @@ static bool llama_kv_cache_find_slot(
 }

 // find how many cells are currently in use
-static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
-    for (uint32_t i = cache.size - 1; i > 0; --i) {
-        if (cache.cells[i].pos >= 0 && !cache.cells[i].is_empty()) {
-            return i + 1;
+static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
+    for (uint32_t i = cache.size; i > 0; --i) {
+        const llama_kv_cell & cell = cache.cells[i - 1];
+
+        if (cell.pos >= 0 && !cell.is_empty()) {
+            return i;
         }
     }
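The hunk above fixes an off-by-one in the scan: the old loop ran from `cache.size - 1` down to `1` and never examined cell 0, so a cache whose only occupied cell was index 0 reported zero cells in use. The return type also becomes `uint32_t` to match the unsigned arithmetic at the call sites. A minimal, self-contained sketch of the corrected scan (the `kv_cell` struct here is a stand-in with only the two members the scan touches, not llama.cpp's full `llama_kv_cell`):

#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

// Stand-in for llama_kv_cell with only the members the scan inspects.
struct kv_cell {
    int32_t pos = -1;
    std::set<int32_t> seq_id;

    bool is_empty() const { return seq_id.empty(); }
};

// Corrected scan: walk from the back and index with i - 1, so that cell 0
// is also inspected; returns (highest used index) + 1, or 0 when empty.
static uint32_t cell_max(const std::vector<kv_cell> & cells) {
    for (uint32_t i = (uint32_t) cells.size(); i > 0; --i) {
        const kv_cell & cell = cells[i - 1];
        if (cell.pos >= 0 && !cell.is_empty()) {
            return i;
        }
    }
    return 0;
}

int main() {
    std::vector<kv_cell> cells(8);
    cells[0].pos = 0;
    cells[0].seq_id.insert(0);

    // the old loop (cache.size - 1 down to 1) would report 0 here
    printf("cells in use: %u\n", cell_max(cells));
}

With only `cells[0]` occupied, the old loop returns 0 while the corrected one returns 1.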
@@ -8178,7 +8180,7 @@ static int llama_decode_internal(
         // a heuristic, to avoid attending the full cache if it is not yet utilized
         // after enough generations, the benefit from this heuristic disappears
         // if we start defragmenting the cache, the benefit from this will be more important
-        kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+        kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
         //kv_self.n = llama_kv_cache_cell_max(kv_self);

         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
@@ -12615,9 +12617,14 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     const size_t s_logits = ctx->logits.capacity() * sizeof(float);
     const size_t s_embedding_size = sizeof(size_t);
     const size_t s_embedding = ctx->embedding.size() * sizeof(float);
-    const size_t s_kv_size = sizeof(size_t);
-    const size_t s_kv_ntok = sizeof(int);
+    const size_t s_kv_buf_size = sizeof(size_t);
+    const size_t s_kv_head = sizeof(uint32_t);
+    const size_t s_kv_size = sizeof(uint32_t);
+    const size_t s_kv_used = sizeof(uint32_t);
     const size_t s_kv = ctx->kv_self.total_size();
+    // TODO: assume the max is more than 1 seq_id per KV cell
+    const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
+    const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell;

     const size_t s_total = (
         + s_rng_size
@@ -12626,9 +12633,12 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
         + s_logits
         + s_embedding_size
         + s_embedding
+        + s_kv_buf_size
+        + s_kv_head
         + s_kv_size
-        + s_kv_ntok
+        + s_kv_used
         + s_kv
+        + s_kv_cells
     );

     return s_total;
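`llama_get_state_size` is an upper bound on the buffer that `llama_copy_state_data` will fill, so the KV block now budgets for four fixed-width header fields (buffer size, head, size, used) plus per-cell metadata for every cell in the cache; the TODO concedes that budgeting a single `llama_seq_id` per cell may undercount cells that hold several. A sketch of the arithmetic, assuming llama.h's `int32_t` typedefs for `llama_pos` and `llama_seq_id`:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// llama.h typedefs both of these as int32_t
typedef int32_t llama_pos;
typedef int32_t llama_seq_id;

int main() {
    // fixed-width header written ahead of the KV tensor data
    const size_t s_kv_header = sizeof(size_t)      // kv_buf_size
                             + sizeof(uint32_t)    // kv_head
                             + sizeof(uint32_t)    // kv_size
                             + sizeof(uint32_t);   // kv_used

    // per-cell metadata: pos + seq_id count + one seq_id
    // (undercounts cells holding several seq_ids, per the TODO)
    const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);

    const uint32_t kv_size = 4096; // assumed cache size
    printf("KV header: %zu bytes, cell metadata: %u cells * %zu = %zu bytes\n",
           s_kv_header, kv_size, s_kv_cell, (size_t) kv_size * s_kv_cell);
}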
@@ -12728,15 +12738,13 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
     {
         const auto & kv_self = ctx->kv_self;
         const auto & hparams = ctx->model.hparams;
-        const auto & cparams = ctx->cparams;

         const uint32_t n_layer = hparams.n_layer;
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-        const uint32_t n_ctx = cparams.n_ctx;

         const size_t kv_buf_size = kv_self.total_size();
-        const uint32_t kv_head = kv_self.head;
+        const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
         const uint32_t kv_size = kv_self.size;
         const uint32_t kv_used = kv_self.used;

@@ -12756,7 +12764,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat

             // v is not contiguous, copy row by row
             const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-            const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
+            const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);

             tmp_buf.resize(v_row_size);
             for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
@@ -12766,7 +12774,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
             }
         }

-        for (uint32_t i = 0; i < kv_size; ++i) {
+        for (uint32_t i = 0; i < kv_head; ++i) {
             const auto & cell = kv_self.cells[i];

             const llama_pos pos = cell.pos;
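Three related changes on the save side: `kv_head` becomes the number of cells actually in use (`llama_kv_cache_cell_max`) rather than the cache's ring-buffer head, the V-row stride is computed from the cache's own allocation `kv_size` instead of `n_ctx` (the two need not match), and only the first `kv_head` cells' metadata are serialized. A sketch of the resulting per-cell format, with a hypothetical byte-vector sink standing in for llama.cpp's `data_ctx`:

#include <cstdint>
#include <set>
#include <vector>

typedef int32_t llama_pos;    // as typedef'd in llama.h
typedef int32_t llama_seq_id; // as typedef'd in llama.h

// stand-in for llama_kv_cell; only the serialized members
struct kv_cell {
    llama_pos pos = -1;
    std::set<llama_seq_id> seq_id;
};

// hypothetical sink standing in for llama.cpp's data_ctx abstraction
static void write_raw(std::vector<uint8_t> & out, const void * src, size_t n) {
    const uint8_t * p = (const uint8_t *) src;
    out.insert(out.end(), p, p + n);
}

// serialize only the first kv_head cells; the tail [kv_head, kv_size)
// is implicitly empty and is reconstructed as such on load
static void write_cells(std::vector<uint8_t> & out,
                        const std::vector<kv_cell> & cells, uint32_t kv_head) {
    for (uint32_t i = 0; i < kv_head; ++i) {
        const kv_cell & cell = cells[i];

        write_raw(out, &cell.pos, sizeof(cell.pos));

        const size_t seq_id_size = cell.seq_id.size();
        write_raw(out, &seq_id_size, sizeof(seq_id_size));

        for (const llama_seq_id seq_id : cell.seq_id) {
            write_raw(out, &seq_id, sizeof(seq_id));
        }
    }
}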
@@ -12842,12 +12850,10 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
     {
         const auto & kv_self = ctx->kv_self;
         const auto & hparams = ctx->model.hparams;
-        const auto & cparams = ctx->cparams;

         const uint32_t n_layer = hparams.n_layer;
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-        const uint32_t n_ctx = cparams.n_ctx;

         size_t kv_buf_size;
         uint32_t kv_head;
@@ -12870,7 +12876,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {

                 // v is not contiguous, copy row by row
                 const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
+                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);

                 for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
                     ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
@@ -12879,13 +12885,15 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
             }
         }

+        GGML_ASSERT(kv_self.size == kv_size);
+
         ctx->kv_self.head = kv_head;
         ctx->kv_self.size = kv_size;
         ctx->kv_self.used = kv_used;

         ctx->kv_self.cells.resize(kv_size);

-        for (uint32_t i = 0; i < kv_size; ++i) {
+        for (uint32_t i = 0; i < kv_head; ++i) {
             llama_pos pos;
             size_t seq_id_size;

@@ -12901,6 +12909,11 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
                 ctx->kv_self.cells[i].seq_id.insert(seq_id);
             }
         }
+
+        for (uint32_t i = kv_head; i < kv_size; ++i) {
+            ctx->kv_self.cells[i].pos = -1;
+            ctx->kv_self.cells[i].seq_id.clear();
+        }
     }

     const size_t nread = inp - src;
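The load side mirrors the same three ideas: the V stride again uses `kv_size`, only `kv_head` cells are read back, and, because `cells.resize(kv_size)` preserves existing elements, the tail `[kv_head, kv_size)` is explicitly reset so no stale positions or seq_ids survive from whatever was in the context before the restore. The new `GGML_ASSERT` also pins down that a session can only be restored into a cache of the same size. A sketch of the restore path, mirroring the hypothetical format from the save-side sketch above:

#include <cstdint>
#include <cstring>
#include <set>
#include <vector>

typedef int32_t llama_pos;    // as typedef'd in llama.h
typedef int32_t llama_seq_id; // as typedef'd in llama.h

struct kv_cell {
    llama_pos pos = -1;
    std::set<llama_seq_id> seq_id;
};

// hypothetical mirror of the save format: pos, seq_id count, seq_ids
static const uint8_t * read_cells(const uint8_t * inp, std::vector<kv_cell> & cells,
                                  uint32_t kv_head, uint32_t kv_size) {
    cells.resize(kv_size);

    for (uint32_t i = 0; i < kv_head; ++i) {
        llama_pos pos;
        size_t seq_id_size;

        memcpy(&pos, inp, sizeof(pos));                 inp += sizeof(pos);
        memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size);

        cells[i].pos = pos;
        cells[i].seq_id.clear();

        for (size_t j = 0; j < seq_id_size; ++j) {
            llama_seq_id seq_id;
            memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id);
            cells[i].seq_id.insert(seq_id);
        }
    }

    // resize() does not clear pre-existing elements, so reset the unused tail
    for (uint32_t i = kv_head; i < kv_size; ++i) {
        cells[i].pos = -1;
        cells[i].seq_id.clear();
    }

    return inp;
}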