@@ -367,17 +367,15 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         return {tokens, -1, logit_history, prob_history};
     }
 
-    const int calc_chunk = n_ctx;
+    fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), n_ctx);
 
-    fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
-
-    if (int(tokens.size()) <= calc_chunk) {
+    if (int(tokens.size()) <= n_ctx) {
         fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
                 tokens.size(), n_ctx, params.ppl_stride);
         return {tokens, -1, logit_history, prob_history};
     }
 
-    const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
+    const int n_chunk_max = (tokens.size() - n_ctx + params.ppl_stride - 1) / params.ppl_stride;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
@@ -386,13 +384,13 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     int count = 0;
     double nll = 0.0;
 
+    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+
     fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 
     for (int i = 0; i < n_chunk; ++i) {
         const int start = i * params.ppl_stride;
-        const int end   = start + calc_chunk;
-
-        const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
+        const int end   = start + n_ctx;
 
         // fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
 
         std::vector<float> logits;
@@ -406,13 +404,27 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
 
+            llama_batch batch = llama_batch_init(batch_size, 0, 1);
+            for (int k = 0; k < batch_size; ++k) {
+                const int idx = batch_start + k;
+                batch.token[k]  = tokens[idx];
+                batch.output[k] = 1;
+            }
+            batch.n_tokens   = batch_size;
+            batch.pos        = nullptr;
+            batch.n_seq_id   = nullptr;
+            batch.seq_id     = nullptr;
+            batch.all_pos_0  = j*n_batch;
+            batch.all_pos_1  = 1;
+            batch.all_seq_id = 0;
+
             // fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
-            // TODO: use llama_batch.output instead of relying on logits_all == true
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, batch)) {
                 // fprintf(stderr, "%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
            }
 
+            llama_batch_free(batch);
             // save original token and restore it after eval
             const auto token_org = tokens[batch_start];
 
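For readers following the new batch setup, below is a minimal sketch (not part of the commit) that factors the per-batch construction from the loop above into a helper. It assumes the llama.h revision targeted by this change, where llama_batch exposes an output array alongside the all_pos_0/all_pos_1/all_seq_id fields used above; the helper name make_output_batch is hypothetical.

// Sketch only: mirrors the batch setup added in this commit. Assumes the
// llama.h used here, where llama_batch has `output`, `all_pos_0`, `all_pos_1`
// and `all_seq_id`; `make_output_batch` is a hypothetical helper name.
#include "llama.h"

#include <vector>

static llama_batch make_output_batch(const std::vector<llama_token> & tokens,
                                     int batch_start, int batch_size, int n_past) {
    llama_batch batch = llama_batch_init(batch_size, 0, 1);
    for (int k = 0; k < batch_size; ++k) {
        batch.token[k]  = tokens[batch_start + k];
        batch.output[k] = 1; // request logits for every position in the batch
    }
    batch.n_tokens = batch_size;
    // as in the diff: clear the per-token arrays and let llama_decode derive
    // positions and sequence ids from the all_* fields
    batch.pos        = nullptr;
    batch.n_seq_id   = nullptr;
    batch.seq_id     = nullptr;
    batch.all_pos_0  = n_past;
    batch.all_pos_1  = 1;
    batch.all_seq_id = 0;
    return batch;
}

// Usage inside the chunk loop (cf. the code added above):
//   llama_batch batch = make_output_batch(tokens, batch_start, batch_size, j * n_batch);
//   if (llama_decode(ctx, batch)) { /* eval failed */ }
//   llama_batch_free(batch);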