@@ -215,8 +215,8 @@ int main(int argc, char ** argv) {
215
215
fprintf (stderr, " Input prefix: '%s'\n " , params.input_prefix .c_str ());
216
216
}
217
217
}
218
- fprintf (stderr, " sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n " ,
219
- params.temp , params.top_k , params.top_p , params.repeat_last_n , params.repeat_penalty );
218
+ fprintf (stderr, " sampling: repeat_last_n = %d, repeat_penalty = % f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n " ,
219
+ params.repeat_last_n , params.repeat_penalty , params. alpha_presence , params. alpha_frequency , params. top_k , params.tfs_z , params. top_p , params.typical_p , params.temp );
220
220
fprintf (stderr, " generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n " , n_ctx, params.n_batch , params.n_predict , params.n_keep );
221
221
fprintf (stderr, " \n\n " );
222
222
@@ -281,23 +281,69 @@ int main(int argc, char ** argv) {
281
281
282
282
if ((int ) embd_inp.size () <= n_consumed && !is_interacting) {
283
283
// out of user input, sample next token
284
- const int32_t top_k = params.top_k ;
285
- const float top_p = params.top_p ;
286
284
const float temp = params.temp ;
285
+ const int32_t top_k = params.top_k <= 0 ? llama_n_vocab (ctx) : params.top_k ;
286
+ const float top_p = params.top_p ;
287
+ const float tfs_z = params.tfs_z ;
288
+ const float typical_p = params.typical_p ;
289
+ const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n ;
287
290
const float repeat_penalty = params.repeat_penalty ;
291
+ const float alpha_presence = params.alpha_presence ;
292
+ const float alpha_frequency = params.alpha_frequency ;
288
293
289
294
llama_token id = 0 ;
290
295
291
296
{
292
297
auto logits = llama_get_logits (ctx);
298
+ auto n_vocab = llama_n_vocab (ctx);
293
299
294
300
if (params.ignore_eos ) {
295
- logits[llama_token_eos ()] = 0 ;
301
+ logits[llama_token_eos ()] = -INFINITY;
302
+ }
303
+
304
+ std::vector<llama_token_data> candidates;
305
+ candidates.reserve (n_vocab);
306
+ for (size_t i = 0 ; i < n_vocab; i++) {
307
+ candidates.emplace_back (i, logits[i], 0 .0f );
296
308
}
297
309
298
- id = llama_sample_top_p_top_k (ctx,
299
- last_n_tokens.data () + n_ctx - params.repeat_last_n ,
300
- params.repeat_last_n , top_k, top_p, temp, repeat_penalty);
310
+ llama_token_data_array candidates_p = { candidates.data (), candidates.size () };
311
+
312
+ // Apply penalties
313
+ auto last_n_repeat = std::min (std::min ((int )last_n_tokens.size (), repeat_last_n), n_ctx);
314
+ llama_sample_repetition_penalty (&candidates_p,
315
+ last_n_tokens.data () + last_n_tokens.size () - last_n_repeat,
316
+ last_n_repeat, repeat_penalty);
317
+ llama_sample_frequency_and_presence_penalties (&candidates_p,
318
+ last_n_tokens.data () + last_n_tokens.size () - last_n_repeat,
319
+ last_n_repeat, alpha_frequency, alpha_presence);
320
+
321
+
322
+ #if 1
323
+ if (temp <= 0 ) {
324
+ // Greedy sampling
325
+ id = llama_sample_token_greedy (ctx, &candidates_p);
326
+ } else {
327
+ // Temperature sampling
328
+ llama_sample_top_k (&candidates_p, top_k);
329
+ llama_sample_tail_free (&candidates_p, tfs_z);
330
+ llama_sample_typical (&candidates_p, typical_p);
331
+ llama_sample_top_p (&candidates_p, top_p);
332
+
333
+ llama_sample_temperature (&candidates_p, temp);
334
+ // printf("`%d`", candidates_p.size);
335
+ id = llama_sample_token (ctx, &candidates_p);
336
+ }
337
+ #else
338
+ const float tau = 5.0f;
339
+ static float mu = 2.0f * tau;
340
+ static int k = 40;
341
+ const float eta = 0.1f;
342
+ const int m = 100;
343
+ const float N = n_vocab;
344
+ id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
345
+ // id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
346
+ #endif
301
347
302
348
last_n_tokens.erase (last_n_tokens.begin ());
303
349
last_n_tokens.push_back (id);
0 commit comments