@@ -69,16 +69,11 @@ int main(int argc, char ** argv) {
69
69
printf (" \n first run: %s" , params.prompt .c_str ());
70
70
71
71
for (auto i = 0 ; i < params.n_predict ; i++) {
72
- auto * logits = llama_get_logits (ctx);
73
- auto n_vocab = llama_n_vocab (model);
72
+ const auto * logits = llama_get_logits (ctx);
74
73
75
- std::vector<llama_token_data> candidates;
76
- candidates.reserve (n_vocab);
77
- for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
78
- candidates.emplace_back (llama_token_data{token_id, logits[token_id], 0 .0f });
79
- }
80
- llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
81
- auto next_token = llama_sampling_sample_dist (smpl, &candidates_p);
74
+ llama_sampling_set_logits (smpl, logits);
75
+
76
+ auto next_token = llama_sampling_sample_dist (smpl, nullptr );
82
77
auto next_token_str = llama_token_to_piece (ctx, next_token);
83
78
84
79
printf (" %s" , next_token_str.c_str ());
@@ -131,15 +126,11 @@ int main(int argc, char ** argv) {
131
126
132
127
// second run
133
128
for (auto i = 0 ; i < params.n_predict ; i++) {
134
- auto * logits = llama_get_logits (ctx2);
135
- auto n_vocab = llama_n_vocab (model);
136
- std::vector<llama_token_data> candidates;
137
- candidates.reserve (n_vocab);
138
- for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
139
- candidates.emplace_back (llama_token_data{token_id, logits[token_id], 0 .0f });
140
- }
141
- llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
142
- auto next_token = llama_sampling_sample_dist (smpl2, &candidates_p);
129
+ const auto * logits = llama_get_logits (ctx2);
130
+
131
+ llama_sampling_set_logits (smpl2, logits);
132
+
133
+ auto next_token = llama_sampling_sample_dist (smpl2, nullptr );
143
134
auto next_token_str = llama_token_to_piece (ctx2, next_token);
144
135
145
136
printf (" %s" , next_token_str.c_str ());
@@ -224,15 +215,11 @@ int main(int argc, char ** argv) {
224
215
225
216
// third run with seq 1 instead of 0
226
217
for (auto i = 0 ; i < params.n_predict ; i++) {
227
- auto * logits = llama_get_logits (ctx3);
228
- auto n_vocab = llama_n_vocab (model);
229
- std::vector<llama_token_data> candidates;
230
- candidates.reserve (n_vocab);
231
- for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
232
- candidates.emplace_back (llama_token_data{token_id, logits[token_id], 0 .0f });
233
- }
234
- llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
235
- auto next_token = llama_sampling_sample_dist (smpl3, &candidates_p);
218
+ const auto * logits = llama_get_logits (ctx3);
219
+
220
+ llama_sampling_set_logits (smpl3, logits);
221
+
222
+ auto next_token = llama_sampling_sample_dist (smpl3, nullptr );
236
223
auto next_token_str = llama_token_to_piece (ctx3, next_token);
237
224
238
225
printf (" %s" , next_token_str.c_str ());
0 commit comments