Skip to content

Commit 02f0c6f

Browse files
beiller, BillHamShop, and ggerganov
authored
Add back top_k (#56)
* Add back top_k * Update utils.cpp * Update utils.h --------- Co-authored-by: Bill Hamilton <[email protected]> Co-authored-by: Georgi Gerganov <[email protected]>
1 parent eb062bb commit 02f0c6f

File tree

3 files changed

+12
-89
lines changed

3 files changed

+12
-89
lines changed

main.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,7 @@ int main(int argc, char ** argv) {
825825

826826
if (i >= embd_inp.size()) {
827827
// sample next token
828+
const float top_k = params.top_k;
828829
const float top_p = params.top_p;
829830
const float temp = params.temp;
830831
const float repeat_penalty = params.repeat_penalty;
@@ -836,7 +837,7 @@ int main(int argc, char ** argv) {
836837
{
837838
const int64_t t_start_sample_us = ggml_time_us();
838839

839-
id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng);
840+
id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
840841

841842
last_n_tokens.erase(last_n_tokens.begin());
842843
last_n_tokens.push_back(id);

utils.cpp

Lines changed: 4 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -301,25 +301,8 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
301301
return true;
302302
}
303303

304-
gpt_vocab::id gpt_sample_top_k_top_p(
305-
const gpt_vocab & vocab,
306-
const float * logits,
307-
int top_k,
308-
double top_p,
309-
double temp,
310-
std::mt19937 & rng) {
311-
int n_logits = vocab.id_to_token.size();
312-
313-
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
314-
logits_id.reserve(n_logits);
315-
316-
{
317-
const double scale = 1.0/temp;
318-
for (int i = 0; i < n_logits; ++i) {
319-
logits_id.push_back(std::make_pair(logits[i]*scale, i));
320-
}
321-
}
322304

305+
void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k) {
323306
// find the top K tokens
324307
std::partial_sort(
325308
logits_id.begin(),
@@ -329,63 +312,14 @@ gpt_vocab::id gpt_sample_top_k_top_p(
329312
});
330313

331314
logits_id.resize(top_k);
332-
333-
double maxl = -INFINITY;
334-
for (const auto & kv : logits_id) {
335-
maxl = std::max(maxl, kv.first);
336-
}
337-
338-
// compute probs for the top K tokens
339-
std::vector<double> probs;
340-
probs.reserve(logits_id.size());
341-
342-
double sum = 0.0;
343-
for (const auto & kv : logits_id) {
344-
double p = exp(kv.first - maxl);
345-
probs.push_back(p);
346-
sum += p;
347-
}
348-
349-
// normalize the probs
350-
for (auto & p : probs) {
351-
p /= sum;
352-
}
353-
354-
if (top_p < 1.0f) {
355-
double cumsum = 0.0f;
356-
for (int i = 0; i < top_k; i++) {
357-
cumsum += probs[i];
358-
if (cumsum >= top_p) {
359-
top_k = i + 1;
360-
probs.resize(top_k);
361-
logits_id.resize(top_k);
362-
break;
363-
}
364-
}
365-
366-
cumsum = 1.0/cumsum;
367-
for (int i = 0; i < (int) probs.size(); i++) {
368-
probs[i] *= cumsum;
369-
}
370-
}
371-
372-
//printf("\n");
373-
//for (int i = 0; i < (int) probs.size(); i++) {
374-
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
375-
//}
376-
//exit(0);
377-
378-
std::discrete_distribution<> dist(probs.begin(), probs.end());
379-
int idx = dist(rng);
380-
381-
return logits_id[idx].second;
382315
}
383316

384-
gpt_vocab::id llama_sample_top_p(
317+
gpt_vocab::id llama_sample_top_p_top_k(
385318
const gpt_vocab & vocab,
386319
const float * logits,
387320
std::vector<gpt_vocab::id> & last_n_tokens,
388321
double repeat_penalty,
322+
int top_k,
389323
double top_p,
390324
double temp,
391325
std::mt19937 & rng) {
@@ -412,12 +346,7 @@ gpt_vocab::id llama_sample_top_p(
412346
}
413347
}
414348

415-
std::sort(
416-
logits_id.begin(),
417-
logits_id.end(),
418-
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
419-
return a.first > b.first;
420-
});
349+
sample_top_k(logits_id, top_k);
421350

422351
double maxl = -INFINITY;
423352
for (const auto & kv : logits_id) {

utils.h

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct gpt_params {
1919
int32_t repeat_last_n = 64; // last n tokens to penalize
2020

2121
// sampling parameters
22-
int32_t top_k = 40; // unused
22+
int32_t top_k = 40;
2323
float top_p = 0.95f;
2424
float temp = 0.80f;
2525
float repeat_penalty = 1.30f;
@@ -77,26 +77,19 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
7777
// - consider only the top K tokens
7878
// - from them, consider only the top tokens with cumulative probability > P
7979
//
80-
// TODO: not sure if this implementation is correct
81-
// TODO: temperature is not implemented
82-
//
83-
gpt_vocab::id gpt_sample_top_k_top_p(
84-
const gpt_vocab & vocab,
85-
const float * logits,
86-
int top_k,
87-
double top_p,
88-
double temp,
89-
std::mt19937 & rng);
90-
91-
gpt_vocab::id llama_sample_top_p(
80+
gpt_vocab::id llama_sample_top_p_top_k(
9281
const gpt_vocab & vocab,
9382
const float * logits,
9483
std::vector<gpt_vocab::id> & last_n_tokens,
9584
double repeat_penalty,
85+
int top_k,
9686
double top_p,
9787
double temp,
9888
std::mt19937 & rng);
9989

90+
// filter to top K tokens from list of logits
91+
void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k);
92+
10093
//
10194
// Quantization
10295
//

0 commit comments

Comments
 (0)