Skip to content

Commit f2c9800

Browse files
committed
passkey : simplify n_past logic
1 parent bda3f2c commit f2c9800

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

examples/passkey/passkey.cpp

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -128,12 +128,14 @@ int main(int argc, char ** argv) {
128128

129129
llama_batch batch = llama_batch_init(512, 0, 1);
130130

131+
int n_past = 0;
132+
131133
// fill the KV cache
132134
for (int i = 0; i < n_ctx; i += n_batch) {
133135
llama_batch_clear(batch);
134136

135137
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
136-
llama_batch_add(batch, tokens_list[i + j], i + j, { 0 }, false);
138+
llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
137139
}
138140

139141
if (i + n_batch >= n_tokens_all) {
@@ -160,10 +162,12 @@ int main(int argc, char ** argv) {
160162
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
161163
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
162164

165+
n_past -= n_discard;
166+
163167
llama_batch_clear(batch);
164168

165169
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
166-
llama_batch_add(batch, tokens_list[i + j], n_ctx - n_discard + j, { 0 }, false);
170+
llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
167171
}
168172

169173
if (i + n_batch >= n_tokens_all) {
@@ -178,8 +182,6 @@ int main(int argc, char ** argv) {
178182
LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
179183
}
180184

181-
int n_past = batch.pos[batch.n_tokens - 1];
182-
183185
{
184186
const int n_discard = n_past - n_ctx + n_predict;
185187

@@ -236,13 +238,12 @@ int main(int argc, char ** argv) {
236238
fflush(stdout);
237239

238240
n_decode += 1;
239-
n_past += 1;
240241

241242
// prepare the next batch
242243
llama_batch_clear(batch);
243244

244245
// push this new token for next evaluation
245-
llama_batch_add(batch, new_token_id, n_past, { 0 }, true);
246+
llama_batch_add(batch, new_token_id, n_past++, { 0 }, true);
246247
}
247248

248249
n_cur += 1;

0 commit comments

Comments (0)