@@ -9036,8 +9036,8 @@ static int llama_decode_internal(
9036
9036
//llama_synchronize(&lctx);
9037
9037
9038
9038
// decide if we need to defrag the kv cache
9039
- if (cparams.defrag_thold >= 0.0f) {
9040
- const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens_all )/float(kv_self.n) : 0.0f;
9039
+ if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
9040
+ const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
9041
9041
9042
9042
// queue defragmentation for next llama_kv_cache_update
9043
9043
if (fragmentation > cparams.defrag_thold) {
@@ -9069,6 +9069,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
9069
9069
// number of cells moved
9070
9070
uint32_t n_moves = 0;
9071
9071
9072
+ // each move requires 6*n_layer tensors (see build_defrag)
9073
+ // - source view, destination view, copy operation
9074
+ // - x2 for keys and values
9075
+ const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
9076
+
9072
9077
// determine which KV cells to move where
9073
9078
//
9074
9079
// cell i moves to ids[i]
@@ -9095,15 +9100,6 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
9095
9100
nh++;
9096
9101
}
9097
9102
9098
- // each move requires 6*n_layer tensors (see build_defrag)
9099
- // - source view, destination view, copy operation
9100
- // - x2 for keys and values
9101
- //
9102
- if (6*(n_moves + nh)*n_layer >= LLAMA_MAX_NODES) {
9103
- // the graph is too big, we cannot move more cells
9104
- break;
9105
- }
9106
-
9107
9103
uint32_t nf = 0;
9108
9104
uint32_t is = n_kv - 1;
9109
9105
@@ -9133,11 +9129,19 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
9133
9129
// are we moving a continuous block of memory?
9134
9130
bool cont = false;
9135
9131
9132
+ // should we stop searching for the next move?
9133
+ bool stop = false;
9134
+
9136
9135
// go back and move the nf cells to the hole
9137
9136
for (; i1 < n_kv; ++i1) {
9138
9137
auto & cell1 = kv_self.cells[i1];
9139
9138
9140
9139
if (cell1.is_empty() || ids[i1] != n_kv) {
9140
+ if (n_moves == max_moves) {
9141
+ stop = true;
9142
+ break;
9143
+ }
9144
+
9141
9145
cont = false;
9142
9146
continue;
9143
9147
}
@@ -9164,6 +9168,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
9164
9168
}
9165
9169
}
9166
9170
9171
+ if (stop || n_moves == max_moves) {
9172
+ break;
9173
+ }
9174
+
9167
9175
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
9168
9176
9169
9177
i0 += nh - 1;
0 commit comments