@@ -64,14 +64,15 @@ int main(int argc, char ** argv) {
64
64
// first run
65
65
printf (" \n %s" , params.prompt .c_str ());
66
66
for (auto i = 0 ; i < params.n_predict ; i++) {
67
- auto next_token = llama_sample_top_p_top_k (
68
- ctx,
69
- &last_n_tokens_data.back () - params.repeat_last_n ,
70
- params.repeat_last_n ,
71
- 40 ,
72
- 1.0 ,
73
- 1.0 ,
74
- 1.1 );
67
+ auto logits = llama_get_logits (ctx);
68
+ auto n_vocab = llama_n_vocab (ctx);
69
+ std::vector<llama_token_data> candidates;
70
+ candidates.reserve (n_vocab);
71
+ for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
72
+ candidates.emplace_back (llama_token_data{token_id, logits[token_id], 0 .0f });
73
+ }
74
+ llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
75
+ auto next_token = llama_sample_token (ctx, &candidates_p);
75
76
auto next_token_str = llama_token_to_str (ctx, next_token);
76
77
last_n_tokens_data.push_back (next_token);
77
78
printf (" %s" , next_token_str);
@@ -106,14 +107,15 @@ int main(int argc, char ** argv) {
106
107
107
108
// second run
108
109
for (auto i = 0 ; i < params.n_predict ; i++) {
109
- auto next_token = llama_sample_top_p_top_k (
110
- ctx2,
111
- &last_n_tokens_data.back () - params.repeat_last_n ,
112
- params.repeat_last_n ,
113
- 40 ,
114
- 1.0 ,
115
- 1.0 ,
116
- 1.1 );
110
+ auto logits = llama_get_logits (ctx2);
111
+ auto n_vocab = llama_n_vocab (ctx2);
112
+ std::vector<llama_token_data> candidates;
113
+ candidates.reserve (n_vocab);
114
+ for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
115
+ candidates.emplace_back (llama_token_data{token_id, logits[token_id], 0 .0f });
116
+ }
117
+ llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
118
+ auto next_token = llama_sample_token (ctx2, &candidates_p);
117
119
auto next_token_str = llama_token_to_str (ctx2, next_token);
118
120
last_n_tokens_data.push_back (next_token);
119
121
printf (" %s" , next_token_str);
0 commit comments