2
2
3
3
#include < random>
4
4
5
- struct llama_sampling_context * llama_sampling_init (const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id ) {
5
+ struct llama_sampling_context * llama_sampling_init (const struct llama_sampling_params & params, struct llama_sampling * smpl ) {
6
6
struct llama_sampling_context * result = new llama_sampling_context ();
7
7
8
8
result->params = params;
9
- result->seq_id = seq_id;
10
- result->ctx = ctx;
9
+ result->smpl = smpl;
11
10
result->grammar = nullptr ;
12
11
13
12
// if there is a grammar, parse it
@@ -43,7 +42,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
43
42
44
43
result->n_valid = 0 ;
45
44
46
- llama_sampling_set_rng_seed (result, params.seed );
45
+ llama_sampling_set_rng_seed (result-> smpl , params.seed );
47
46
48
47
return result;
49
48
}
@@ -79,13 +78,6 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
79
78
ctx->n_valid = 0 ;
80
79
}
81
80
82
- void llama_sampling_set_rng_seed (struct llama_sampling_context * ctx, uint32_t seed) {
83
- if (seed == LLAMA_DEFAULT_SEED) {
84
- seed = std::random_device{}();
85
- }
86
- llama_set_rng_seed_seq (ctx->ctx , seed, ctx->seq_id );
87
- }
88
-
89
81
void llama_sampling_cp (llama_sampling_context * src, llama_sampling_context * dst) {
90
82
if (dst->grammar ) {
91
83
llama_grammar_free (dst->grammar );
@@ -230,10 +222,13 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
230
222
231
223
// no reasons to expose this function in header
232
224
static void sampler_queue (
233
- struct llama_context * ctx_main,
234
- const llama_sampling_params & params,
225
+ struct llama_sampling_context * ctx_sampling,
235
226
llama_token_data_array & cur_p,
236
227
size_t min_keep) {
228
+ llama_sampling * smpl = ctx_sampling->smpl ;
229
+
230
+ const llama_sampling_params & params = ctx_sampling->params ;
231
+
237
232
const float temp = params.temp ;
238
233
const float dynatemp_range = params.dynatemp_range ;
239
234
const float dynatemp_exponent = params.dynatemp_exponent ;
@@ -246,18 +241,18 @@ static void sampler_queue(
246
241
247
242
for (auto sampler_type : samplers_sequence) {
248
243
switch (sampler_type) {
249
- case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main , &cur_p, top_k, min_keep); break ;
250
- case llama_sampler_type::TFS_Z : llama_sample_tail_free (ctx_main , &cur_p, tfs_z, min_keep); break ;
251
- case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main , &cur_p, typical_p, min_keep); break ;
252
- case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main , &cur_p, top_p, min_keep); break ;
253
- case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main , &cur_p, min_p, min_keep); break ;
244
+ case llama_sampler_type::TOP_K : llama_sampling_top_k (smpl , &cur_p, top_k, min_keep); break ;
245
+ case llama_sampler_type::TFS_Z : llama_sampling_tail_free (smpl , &cur_p, tfs_z, min_keep); break ;
246
+ case llama_sampler_type::TYPICAL_P: llama_sampling_typical (smpl , &cur_p, typical_p, min_keep); break ;
247
+ case llama_sampler_type::TOP_P : llama_sampling_top_p (smpl , &cur_p, top_p, min_keep); break ;
248
+ case llama_sampler_type::MIN_P : llama_sampling_min_p (smpl , &cur_p, min_p, min_keep); break ;
254
249
case llama_sampler_type::TEMPERATURE:
255
250
if (dynatemp_range > 0 ) {
256
251
float dynatemp_min = std::max (0 .0f , temp - dynatemp_range);
257
252
float dynatemp_max = std::max (0 .0f , temp + dynatemp_range);
258
- llama_sample_entropy (ctx_main , &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
253
+ llama_sampling_entropy (smpl , &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
259
254
} else {
260
- llama_sample_temp (ctx_main , &cur_p, temp);
255
+ llama_sampling_temp (smpl , &cur_p, temp);
261
256
}
262
257
break ;
263
258
default : break ;
@@ -271,6 +266,8 @@ static llama_token llama_sampling_sample_impl(
271
266
struct llama_context * ctx_cfg,
272
267
const int idx,
273
268
bool is_resampling) {
269
+ llama_sampling * smpl = ctx_sampling->smpl ;
270
+
274
271
const llama_sampling_params & params = ctx_sampling->params ;
275
272
276
273
const float temp = params.temp ;
@@ -287,26 +284,26 @@ static llama_token llama_sampling_sample_impl(
287
284
288
285
if (temp < 0.0 ) {
289
286
// greedy sampling, with probs
290
- llama_sample_softmax (ctx_main , &cur_p);
287
+ llama_sampling_softmax (smpl , &cur_p);
291
288
id = cur_p.data [0 ].id ;
292
289
} else if (temp == 0.0 ) {
293
290
// greedy sampling, no probs
294
- id = llama_sample_token_greedy (ctx_main , &cur_p);
291
+ id = llama_sampling_sample_greedy (smpl , &cur_p);
295
292
} else {
296
293
if (mirostat == 1 ) {
297
294
const int mirostat_m = 100 ;
298
- llama_sample_temp (ctx_main , &cur_p, temp);
299
- id = llama_sample_token_mirostat (ctx_main , &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu );
295
+ llama_sampling_temp (smpl , &cur_p, temp);
296
+ id = llama_sampling_sample_mirostat (smpl , &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu );
300
297
} else if (mirostat == 2 ) {
301
- llama_sample_temp (ctx_main , &cur_p, temp);
302
- id = llama_sample_token_mirostat_v2 (ctx_main , &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu );
298
+ llama_sampling_temp (smpl , &cur_p, temp);
299
+ id = llama_sampling_sample_mirostat_v2 (smpl , &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu );
303
300
} else {
304
301
// temperature sampling
305
302
size_t min_keep = std::max (1 , params.min_keep );
306
303
307
- sampler_queue (ctx_main, params , cur_p, min_keep);
304
+ sampler_queue (ctx_sampling , cur_p, min_keep);
308
305
309
- id = llama_sample_token_seq (ctx_main , &cur_p, ctx_sampling-> seq_id );
306
+ id = llama_sampling_sample (smpl , &cur_p);
310
307
311
308
// {
312
309
// const int n_top = 10;
@@ -315,11 +312,11 @@ static llama_token llama_sampling_sample_impl(
315
312
// for (int i = 0; i < n_top; i++) {
316
313
// const llama_token id = cur_p.data[i].id;
317
314
// (void)id; // To avoid a warning that id is unused when logging is disabled.
318
- // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main , id).c_str(), cur_p.data[i].p);
315
+ // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl , id).c_str(), cur_p.data[i].p);
319
316
// }
320
317
// }
321
318
322
- // LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main , id).c_str());
319
+ // LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(smpl , id).c_str());
323
320
}
324
321
}
325
322
@@ -360,6 +357,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
360
357
const int idx,
361
358
bool apply_grammar,
362
359
std::vector<float > * original_logits) {
360
+ llama_sampling * smpl = ctx_sampling->smpl ;
361
+
363
362
const llama_sampling_params & params = ctx_sampling->params ;
364
363
365
364
const int n_vocab = llama_n_vocab (llama_get_model (ctx_main));
@@ -390,7 +389,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
390
389
391
390
if (ctx_cfg) {
392
391
float * logits_guidance = llama_get_logits_ith (ctx_cfg, idx);
393
- llama_sample_apply_guidance (ctx_main , logits, logits_guidance, params.cfg_scale );
392
+ llama_sampling_apply_guidance (smpl , logits, logits_guidance, params.cfg_scale );
394
393
}
395
394
396
395
cur.resize (n_vocab);
@@ -407,7 +406,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
407
406
if (penalty_tokens_used_size) {
408
407
const float nl_logit = logits[llama_token_nl (llama_get_model (ctx_main))];
409
408
410
- llama_sample_repetition_penalties (ctx_main , &cur_p,
409
+ llama_sampling_repetition_penalties (smpl , &cur_p,
411
410
penalty_tokens.data () + penalty_tokens.size () - penalty_tokens_used_size,
412
411
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
413
412
@@ -445,7 +444,7 @@ llama_token_data_array llama_sampling_prepare(
445
444
const int idx,
446
445
bool apply_grammar,
447
446
std::vector<float > * original_logits) {
448
- return llama_sampling_prepare_impl (ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
447
+ return llama_sampling_prepare_impl (ctx_sampling, ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
449
448
}
450
449
451
450
void llama_sampling_accept (
0 commit comments