@@ -301,25 +301,8 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
301
301
return true ;
302
302
}
303
303
304
- gpt_vocab::id gpt_sample_top_k_top_p (
305
- const gpt_vocab & vocab,
306
- const float * logits,
307
- int top_k,
308
- double top_p,
309
- double temp,
310
- std::mt19937 & rng) {
311
- int n_logits = vocab.id_to_token .size ();
312
-
313
- std::vector<std::pair<double , gpt_vocab::id>> logits_id;
314
- logits_id.reserve (n_logits);
315
-
316
- {
317
- const double scale = 1.0 /temp;
318
- for (int i = 0 ; i < n_logits; ++i) {
319
- logits_id.push_back (std::make_pair (logits[i]*scale, i));
320
- }
321
- }
322
304
305
+ void sample_top_k (std::vector<std::pair<double , gpt_vocab::id>> & logits_id, int top_k) {
323
306
// find the top K tokens
324
307
std::partial_sort (
325
308
logits_id.begin (),
@@ -329,63 +312,14 @@ gpt_vocab::id gpt_sample_top_k_top_p(
329
312
});
330
313
331
314
logits_id.resize (top_k);
332
-
333
- double maxl = -INFINITY;
334
- for (const auto & kv : logits_id) {
335
- maxl = std::max (maxl, kv.first );
336
- }
337
-
338
- // compute probs for the top K tokens
339
- std::vector<double > probs;
340
- probs.reserve (logits_id.size ());
341
-
342
- double sum = 0.0 ;
343
- for (const auto & kv : logits_id) {
344
- double p = exp (kv.first - maxl);
345
- probs.push_back (p);
346
- sum += p;
347
- }
348
-
349
- // normalize the probs
350
- for (auto & p : probs) {
351
- p /= sum;
352
- }
353
-
354
- if (top_p < 1 .0f ) {
355
- double cumsum = 0 .0f ;
356
- for (int i = 0 ; i < top_k; i++) {
357
- cumsum += probs[i];
358
- if (cumsum >= top_p) {
359
- top_k = i + 1 ;
360
- probs.resize (top_k);
361
- logits_id.resize (top_k);
362
- break ;
363
- }
364
- }
365
-
366
- cumsum = 1.0 /cumsum;
367
- for (int i = 0 ; i < (int ) probs.size (); i++) {
368
- probs[i] *= cumsum;
369
- }
370
- }
371
-
372
- // printf("\n");
373
- // for (int i = 0; i < (int) probs.size(); i++) {
374
- // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
375
- // }
376
- // exit(0);
377
-
378
- std::discrete_distribution<> dist (probs.begin (), probs.end ());
379
- int idx = dist (rng);
380
-
381
- return logits_id[idx].second ;
382
315
}
383
316
384
- gpt_vocab::id llama_sample_top_p (
317
+ gpt_vocab::id llama_sample_top_p_top_k (
385
318
const gpt_vocab & vocab,
386
319
const float * logits,
387
320
std::vector<gpt_vocab::id> & last_n_tokens,
388
321
double repeat_penalty,
322
+ int top_k,
389
323
double top_p,
390
324
double temp,
391
325
std::mt19937 & rng) {
@@ -412,12 +346,7 @@ gpt_vocab::id llama_sample_top_p(
412
346
}
413
347
}
414
348
415
- std::sort (
416
- logits_id.begin (),
417
- logits_id.end (),
418
- [](const std::pair<double , gpt_vocab::id> & a, const std::pair<double , gpt_vocab::id> & b) {
419
- return a.first > b.first ;
420
- });
349
+ sample_top_k (logits_id, top_k);
421
350
422
351
double maxl = -INFINITY;
423
352
for (const auto & kv : logits_id) {
0 commit comments