Skip to content

Commit 233461f

Browse files
sampling : Integrate Top-nσ into main sampling chain (and add it to the server) (#13264)
* sampling: add Top-nσ sampler to `llama-server` and sampler ordering
* revert: sampler ordering
* revert: VS' crappy auto-formatting
* revert: VS' crappy auto-formatting pt.2
* revert: my crappy eye sight...
* sampling: add XTC to Top-nσ sampler chain
* sampling: add Dyna. Temp. to Top-nσ sampler chain
* sampling: actually remove Top-nσ from sampler(oops)
* Integrate top_n_sigma into main sampler chain
* Define COMMON_SAMPLER_TYPE_TOP_N_SIGMA
* Formatting
* Lint
* Exit early in the sampler if nsigma < 0

---------

Co-authored-by: CasualAutopsy <[email protected]>
1 parent b34c859 commit 233461f

File tree

4 files changed

+54
-44
lines changed

4 files changed

+54
-44
lines changed

common/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ enum common_sampler_type {
9696
COMMON_SAMPLER_TYPE_XTC = 8,
9797
COMMON_SAMPLER_TYPE_INFILL = 9,
9898
COMMON_SAMPLER_TYPE_PENALTIES = 10,
99+
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
99100
};
100101

101102
// dimensionality reduction methods, used by cvector-generator
@@ -161,6 +162,7 @@ struct common_params_sampling {
161162
std::vector<enum common_sampler_type> samplers = {
162163
COMMON_SAMPLER_TYPE_PENALTIES,
163164
COMMON_SAMPLER_TYPE_DRY,
165+
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
164166
COMMON_SAMPLER_TYPE_TOP_K,
165167
COMMON_SAMPLER_TYPE_TYPICAL_P,
166168
COMMON_SAMPLER_TYPE_TOP_P,

common/sampling.cpp

Lines changed: 46 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -229,51 +229,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
229229
params.logit_bias.data()));
230230

231231
if (params.mirostat == 0) {
232-
if (params.top_n_sigma >= 0) {
233-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
234-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
235-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
236-
} else {
237-
for (const auto & cnstr : params.samplers) {
238-
switch (cnstr) {
239-
case COMMON_SAMPLER_TYPE_DRY:
240-
{
241-
std::vector<const char *> c_breakers;
242-
c_breakers.reserve(params.dry_sequence_breakers.size());
243-
for (const auto & str : params.dry_sequence_breakers) {
244-
c_breakers.push_back(str.c_str());
245-
}
246-
247-
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
232+
for (const auto & cnstr : params.samplers) {
233+
switch (cnstr) {
234+
case COMMON_SAMPLER_TYPE_DRY:
235+
{
236+
std::vector<const char *> c_breakers;
237+
c_breakers.reserve(params.dry_sequence_breakers.size());
238+
for (const auto & str : params.dry_sequence_breakers) {
239+
c_breakers.push_back(str.c_str());
248240
}
249-
break;
250-
case COMMON_SAMPLER_TYPE_TOP_K:
251-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
252-
break;
253-
case COMMON_SAMPLER_TYPE_TOP_P:
254-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
255-
break;
256-
case COMMON_SAMPLER_TYPE_MIN_P:
257-
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
258-
break;
259-
case COMMON_SAMPLER_TYPE_XTC:
260-
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
261-
break;
262-
case COMMON_SAMPLER_TYPE_TYPICAL_P:
263-
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
264-
break;
265-
case COMMON_SAMPLER_TYPE_TEMPERATURE:
266-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
267-
break;
268-
case COMMON_SAMPLER_TYPE_INFILL:
269-
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
270-
break;
271-
case COMMON_SAMPLER_TYPE_PENALTIES:
272-
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
273-
break;
274-
default:
275-
GGML_ASSERT(false && "unknown sampler type");
276-
}
241+
242+
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
243+
}
244+
break;
245+
case COMMON_SAMPLER_TYPE_TOP_K:
246+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
247+
break;
248+
case COMMON_SAMPLER_TYPE_TOP_P:
249+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
250+
break;
251+
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
252+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
253+
break;
254+
case COMMON_SAMPLER_TYPE_MIN_P:
255+
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
256+
break;
257+
case COMMON_SAMPLER_TYPE_XTC:
258+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
259+
break;
260+
case COMMON_SAMPLER_TYPE_TYPICAL_P:
261+
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
262+
break;
263+
case COMMON_SAMPLER_TYPE_TEMPERATURE:
264+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
265+
break;
266+
case COMMON_SAMPLER_TYPE_INFILL:
267+
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
268+
break;
269+
case COMMON_SAMPLER_TYPE_PENALTIES:
270+
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
271+
break;
272+
default:
273+
GGML_ASSERT(false && "unknown sampler type");
277274
}
278275
}
279276
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
@@ -475,6 +472,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
475472
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
476473
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
477474
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
475+
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
478476
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
479477
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
480478
case COMMON_SAMPLER_TYPE_XTC: return 'x';
@@ -490,6 +488,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
490488
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
491489
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
492490
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
491+
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
493492
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
494493
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
495494
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
@@ -504,6 +503,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
504503
{ "dry", COMMON_SAMPLER_TYPE_DRY },
505504
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
506505
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
506+
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
507507
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
508508
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
509509
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
@@ -517,6 +517,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
517517
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
518518
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
519519
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
520+
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
520521
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
521522
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
522523
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
@@ -552,6 +553,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
552553
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
553554
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
554555
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
556+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
555557
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
556558
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
557559
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },

src/llama-sampling.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,6 +1750,10 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler *
17501750
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
17511751
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
17521752

1753+
if (ctx->n < 0.0f) {
1754+
return;
1755+
}
1756+
17531757
// find max logit and calculate mean
17541758
float max = cur_p->data[0].logit;
17551759
float logits_sum = 0;

tools/server/server.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ struct slot_params {
146146
{"top_k", sampling.top_k},
147147
{"top_p", sampling.top_p},
148148
{"min_p", sampling.min_p},
149+
{"top_n_sigma", sampling.top_n_sigma},
149150
{"xtc_probability", sampling.xtc_probability},
150151
{"xtc_threshold", sampling.xtc_threshold},
151152
{"typical_p", sampling.typ_p},
@@ -248,6 +249,7 @@ struct server_task {
248249
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
249250
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
250251
params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
252+
params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma);
251253
params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
252254
params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
253255
params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);

0 commit comments

Comments
 (0)