@@ -229,51 +229,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
229
229
params.logit_bias .data ()));
230
230
231
231
if (params.mirostat == 0 ) {
232
- if (params.top_n_sigma >= 0 ) {
233
- llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
234
- llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
235
- llama_sampler_chain_add (result->chain , llama_sampler_init_top_n_sigma (params.top_n_sigma ));
236
- } else {
237
- for (const auto & cnstr : params.samplers ) {
238
- switch (cnstr) {
239
- case COMMON_SAMPLER_TYPE_DRY:
240
- {
241
- std::vector<const char *> c_breakers;
242
- c_breakers.reserve (params.dry_sequence_breakers .size ());
243
- for (const auto & str : params.dry_sequence_breakers ) {
244
- c_breakers.push_back (str.c_str ());
245
- }
246
-
247
- llama_sampler_chain_add (result->chain , llama_sampler_init_dry (vocab, llama_model_n_ctx_train (model), params.dry_multiplier , params.dry_base , params.dry_allowed_length , params.dry_penalty_last_n , c_breakers.data (), c_breakers.size ()));
232
+ for (const auto & cnstr : params.samplers ) {
233
+ switch (cnstr) {
234
+ case COMMON_SAMPLER_TYPE_DRY:
235
+ {
236
+ std::vector<const char *> c_breakers;
237
+ c_breakers.reserve (params.dry_sequence_breakers .size ());
238
+ for (const auto & str : params.dry_sequence_breakers ) {
239
+ c_breakers.push_back (str.c_str ());
248
240
}
249
- break ;
250
- case COMMON_SAMPLER_TYPE_TOP_K:
251
- llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
252
- break ;
253
- case COMMON_SAMPLER_TYPE_TOP_P:
254
- llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
255
- break ;
256
- case COMMON_SAMPLER_TYPE_MIN_P:
257
- llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
258
- break ;
259
- case COMMON_SAMPLER_TYPE_XTC:
260
- llama_sampler_chain_add (result->chain , llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold , params.min_keep , params.seed ));
261
- break ;
262
- case COMMON_SAMPLER_TYPE_TYPICAL_P:
263
- llama_sampler_chain_add (result->chain , llama_sampler_init_typical (params.typ_p , params.min_keep ));
264
- break ;
265
- case COMMON_SAMPLER_TYPE_TEMPERATURE:
266
- llama_sampler_chain_add (result->chain , llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
267
- break ;
268
- case COMMON_SAMPLER_TYPE_INFILL:
269
- llama_sampler_chain_add (result->chain , llama_sampler_init_infill (vocab));
270
- break ;
271
- case COMMON_SAMPLER_TYPE_PENALTIES:
272
- llama_sampler_chain_add (result->chain , llama_sampler_init_penalties (params.penalty_last_n , params.penalty_repeat , params.penalty_freq , params.penalty_present ));
273
- break ;
274
- default :
275
- GGML_ASSERT (false && " unknown sampler type" );
276
- }
241
+
242
+ llama_sampler_chain_add (result->chain , llama_sampler_init_dry (vocab, llama_model_n_ctx_train (model), params.dry_multiplier , params.dry_base , params.dry_allowed_length , params.dry_penalty_last_n , c_breakers.data (), c_breakers.size ()));
243
+ }
244
+ break ;
245
+ case COMMON_SAMPLER_TYPE_TOP_K:
246
+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
247
+ break ;
248
+ case COMMON_SAMPLER_TYPE_TOP_P:
249
+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
250
+ break ;
251
+ case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
252
+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_n_sigma (params.top_n_sigma ));
253
+ break ;
254
+ case COMMON_SAMPLER_TYPE_MIN_P:
255
+ llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
256
+ break ;
257
+ case COMMON_SAMPLER_TYPE_XTC:
258
+ llama_sampler_chain_add (result->chain , llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold , params.min_keep , params.seed ));
259
+ break ;
260
+ case COMMON_SAMPLER_TYPE_TYPICAL_P:
261
+ llama_sampler_chain_add (result->chain , llama_sampler_init_typical (params.typ_p , params.min_keep ));
262
+ break ;
263
+ case COMMON_SAMPLER_TYPE_TEMPERATURE:
264
+ llama_sampler_chain_add (result->chain , llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
265
+ break ;
266
+ case COMMON_SAMPLER_TYPE_INFILL:
267
+ llama_sampler_chain_add (result->chain , llama_sampler_init_infill (vocab));
268
+ break ;
269
+ case COMMON_SAMPLER_TYPE_PENALTIES:
270
+ llama_sampler_chain_add (result->chain , llama_sampler_init_penalties (params.penalty_last_n , params.penalty_repeat , params.penalty_freq , params.penalty_present ));
271
+ break ;
272
+ default :
273
+ GGML_ASSERT (false && " unknown sampler type" );
277
274
}
278
275
}
279
276
llama_sampler_chain_add (result->chain , llama_sampler_init_dist (params.seed ));
@@ -475,6 +472,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
475
472
case COMMON_SAMPLER_TYPE_TOP_K: return ' k' ;
476
473
case COMMON_SAMPLER_TYPE_TYPICAL_P: return ' y' ;
477
474
case COMMON_SAMPLER_TYPE_TOP_P: return ' p' ;
475
+ case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return ' s' ;
478
476
case COMMON_SAMPLER_TYPE_MIN_P: return ' m' ;
479
477
case COMMON_SAMPLER_TYPE_TEMPERATURE: return ' t' ;
480
478
case COMMON_SAMPLER_TYPE_XTC: return ' x' ;
@@ -490,6 +488,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
490
488
case COMMON_SAMPLER_TYPE_TOP_K: return " top_k" ;
491
489
case COMMON_SAMPLER_TYPE_TYPICAL_P: return " typ_p" ;
492
490
case COMMON_SAMPLER_TYPE_TOP_P: return " top_p" ;
491
+ case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return " top_n_sigma" ;
493
492
case COMMON_SAMPLER_TYPE_MIN_P: return " min_p" ;
494
493
case COMMON_SAMPLER_TYPE_TEMPERATURE: return " temperature" ;
495
494
case COMMON_SAMPLER_TYPE_XTC: return " xtc" ;
@@ -504,6 +503,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
504
503
{ " dry" , COMMON_SAMPLER_TYPE_DRY },
505
504
{ " top_k" , COMMON_SAMPLER_TYPE_TOP_K },
506
505
{ " top_p" , COMMON_SAMPLER_TYPE_TOP_P },
506
+ { " top_n_sigma" , COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
507
507
{ " typ_p" , COMMON_SAMPLER_TYPE_TYPICAL_P },
508
508
{ " min_p" , COMMON_SAMPLER_TYPE_MIN_P },
509
509
{ " temperature" , COMMON_SAMPLER_TYPE_TEMPERATURE },
@@ -517,6 +517,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
517
517
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
518
518
{ " top-k" , COMMON_SAMPLER_TYPE_TOP_K },
519
519
{ " top-p" , COMMON_SAMPLER_TYPE_TOP_P },
520
+ { " top-n-sigma" , COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
520
521
{ " nucleus" , COMMON_SAMPLER_TYPE_TOP_P },
521
522
{ " typical-p" , COMMON_SAMPLER_TYPE_TYPICAL_P },
522
523
{ " typical" , COMMON_SAMPLER_TYPE_TYPICAL_P },
@@ -552,6 +553,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
552
553
{ common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
553
554
{ common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
554
555
{ common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
556
+ { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
555
557
{ common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
556
558
{ common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
557
559
{ common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
0 commit comments