@@ -128,12 +128,14 @@ int main(int argc, char ** argv) {
 
     llama_batch batch = llama_batch_init(512, 0, 1);
 
+    int n_past = 0;
+
     // fill the KV cache
     for (int i = 0; i < n_ctx; i += n_batch) {
         llama_batch_clear(batch);
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-            llama_batch_add(batch, tokens_list[i + j], i + j, { 0 }, false);
+            llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
         }
 
         if (i + n_batch >= n_tokens_all) {
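The first hunk introduces a running n_past counter and uses it, rather than the absolute prompt index i + j, as the position passed to llama_batch_add. The two are identical while the cache only grows; only n_past stays correct once cache entries are discarded and shifted, as in the next hunk.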
@@ -160,10 +162,12 @@ int main(int argc, char ** argv) {
         llama_kv_cache_seq_rm   (ctx, 0, n_keep            , n_keep + n_discard);
         llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
 
+        n_past -= n_discard;
+
         llama_batch_clear(batch);
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-            llama_batch_add(batch, tokens_list[i + j], n_ctx - n_discard + j, { 0 }, false);
+            llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
         }
 
         if (i + n_batch >= n_tokens_all) {
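To make the shift arithmetic concrete (illustrative values, not the example's defaults): with n_keep = 1 and n_discard = 256, llama_kv_cache_seq_rm drops cache positions [1, 257), llama_kv_cache_seq_shift slides positions [257, n_ctx) down by 256, and n_past -= n_discard keeps the counter in step, so the next llama_batch_add writes at the first freed position instead of past the end of the context.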
@@ -178,8 +182,6 @@ int main(int argc, char ** argv) {
         LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
     }
 
-    int n_past = batch.pos[batch.n_tokens - 1];
-
     {
         const int n_discard = n_past - n_ctx + n_predict;
 
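With the counter maintained incrementally, it no longer needs to be recovered from the last batch entry. Note that batch.pos[batch.n_tokens - 1] is the position of the last batched token, one less than the number of cached tokens; the standalone n_past += 1 removed in the final hunk below was the old code's compensation for that offset.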
@@ -236,13 +238,12 @@ int main(int argc, char ** argv) {
             fflush(stdout);
 
             n_decode += 1;
-            n_past   += 1;
 
             // prepare the next batch
             llama_batch_clear(batch);
 
             // push this new token for next evaluation
-            llama_batch_add(batch, new_token_id, n_past, { 0 }, true);
+            llama_batch_add(batch, new_token_id, n_past++, { 0 }, true);
         }
 
         n_cur += 1;
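Taken together, the commit replaces derived position bookkeeping with a single running counter: increment it once per token appended to a batch, decrement it by n_discard whenever the cache is compacted. A condensed sketch of the resulting flow (abbreviated from the diff above; the surrounding control flow and the llama.cpp common helpers llama_batch_add / llama_batch_clear are assumed):

    int n_past = 0;  // number of tokens currently in the KV cache

    // fill: each token's position is assigned as it enters the cache
    for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
        llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
    }

    // shift: the cache now holds n_discard fewer tokens, so the counter
    // shrinks by the same amount before any new tokens are appended
    llama_kv_cache_seq_rm   (ctx, 0, n_keep            , n_keep + n_discard);
    llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
    n_past -= n_discard;

    // decode: each sampled token advances the counter exactly once,
    // at the moment it is pushed for the next evaluation
    llama_batch_add(batch, new_token_id, n_past++, { 0 }, true);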