Skip to content

Commit 81a17c8

Browse files
MaggotHATEarthw
authored andcommitted
sampling : add XTC sampler (ggml-org#9742)
* Initial XTC commit Adds XTC sampler, not activated by default, but recommended settings by default. * Cleanup * Simplified chances calculation To be more inline with the original implementation, chance is calculated once at the beginning. * First fixes by comments Still need to look into sorting * Fixed trailing backspaces * Fixed RNG to be reproduceable Thanks to @slaren for directions * Fixed forgotten header * Moved `min_keep` Moved from conditions to a simple check at the end. * Fixed broken randomization Thanks to @slaren for explanation * Swapped sorting for a custom algorithm Shifts tokens to remove the penalized ones, then puts the penalized at the back. Should make `min_keep` still viable. * Algorithm rework 1. Scan token from top till the first non-penalizable 2. Remove the last captured token (the least probable above threshold) 3. Shift all tokens to override the remaining penalizable 4. Penalize and put them at the the bottom. * Added XTC to `test-sampling` * Simplified algorithm and more tests * Updated info in common and args * Merged back lost commits in common and arg * Update dump info in common * Fixed incorrect min_keep check * Added XTC to README * Renamed parameters, fixed info and defaults * probability is at 0 by default, but XTC is included in sampling queue * threshold higher than 0.5 switches XTC off * Initial server support * Added XTC to server UIs * Fixed labels in old server UI * Made algorithm safer and more readable * Removed xtc_threshold_max * Fixed arg after update * Quick fixes by comments * Simplified algorithm since threshold_max is removed * Renamed random distribution * Fixed tests and outdated README * Small fixes
1 parent cb188ff commit 81a17c8

File tree

11 files changed

+195
-10
lines changed

11 files changed

+195
-10
lines changed

common/arg.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
947947
params.sparams.tfs_z = std::stof(value);
948948
}
949949
).set_sparam());
950+
add_opt(common_arg(
951+
{"--xtc-probability"}, "N",
952+
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability),
953+
[](common_params & params, const std::string & value) {
954+
params.sparams.xtc_probability = std::stof(value);
955+
}
956+
).set_sparam());
957+
add_opt(common_arg(
958+
{"--xtc-threshold"}, "N",
959+
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sparams.xtc_threshold),
960+
[](common_params & params, const std::string & value) {
961+
params.sparams.xtc_threshold = std::stof(value);
962+
}
963+
).set_sparam());
950964
add_opt(common_arg(
951965
{"--typical"}, "N",
952966
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),

common/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,6 +2104,8 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
21042104
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
21052105
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
21062106
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
2107+
fprintf(stream, "xtc_probability: %f # default: 0.0\n", sparams.xtc_probability);
2108+
fprintf(stream, "xtc_threshold: %f # default: 0.1\n", sparams.xtc_threshold);
21072109
fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
21082110
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
21092111
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");

common/common.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ enum common_sampler_type {
9090
COMMON_SAMPLER_TYPE_TFS_Z = 4,
9191
COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
9292
COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
93+
COMMON_SAMPLER_TYPE_XTC = 7,
94+
9395
};
9496

9597
// dimensionality reduction methods, used by cvector-generator
@@ -108,6 +110,8 @@ struct common_sampler_params {
108110
int32_t top_k = 40; // <= 0 to use vocab size
109111
float top_p = 0.95f; // 1.0 = disabled
110112
float min_p = 0.05f; // 0.0 = disabled
113+
float xtc_probability = 0.00f; // 0.0 = disabled
114+
float xtc_threshold = 0.10f; // > 0.5 disables XTC
111115
float tfs_z = 1.00f; // 1.0 = disabled
112116
float typ_p = 1.00f; // typical_p, 1.0 = disabled
113117
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
@@ -124,12 +128,14 @@ struct common_sampler_params {
124128
bool ignore_eos = false;
125129
bool no_perf = false; // disable performance metrics
126130

131+
127132
std::vector<enum common_sampler_type> samplers = {
128133
COMMON_SAMPLER_TYPE_TOP_K,
129134
COMMON_SAMPLER_TYPE_TFS_Z,
130135
COMMON_SAMPLER_TYPE_TYPICAL_P,
131136
COMMON_SAMPLER_TYPE_TOP_P,
132137
COMMON_SAMPLER_TYPE_MIN_P,
138+
COMMON_SAMPLER_TYPE_XTC,
133139
COMMON_SAMPLER_TYPE_TEMPERATURE
134140
};
135141

common/sampling.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ std::string common_sampler_params::print() const {
130130

131131
snprintf(result, sizeof(result),
132132
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
133-
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
133+
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
134134
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
135135
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
136-
top_k, tfs_z, top_p, min_p, typ_p, temp,
136+
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
137137
mirostat, mirostat_eta, mirostat_tau);
138138

139139
return std::string(result);
@@ -184,6 +184,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
184184
case COMMON_SAMPLER_TYPE_MIN_P:
185185
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
186186
break;
187+
case COMMON_SAMPLER_TYPE_XTC:
188+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
189+
break;
187190
case COMMON_SAMPLER_TYPE_TFS_Z:
188191
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
189192
break;
@@ -372,6 +375,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
372375
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
373376
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
374377
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
378+
case COMMON_SAMPLER_TYPE_XTC: return 'x';
375379
default : return '?';
376380
}
377381
}
@@ -384,6 +388,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
384388
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
385389
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
386390
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
391+
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
387392
default : return "";
388393
}
389394
}
@@ -396,6 +401,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
396401
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
397402
{ "tfs_z", COMMON_SAMPLER_TYPE_TFS_Z },
398403
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
404+
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
399405
};
400406

401407
// since samplers names are written multiple ways
@@ -441,7 +447,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
441447
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
442448
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
443449
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
444-
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }
450+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
451+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }
445452
};
446453

447454
std::vector<common_sampler_type> samplers;

examples/main/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,19 @@ The `--mirostat-ent` option sets the Mirostat target entropy (tau), which repres
241241

242242
Example usage: `--mirostat 2 --mirostat-lr 0.05 --mirostat-ent 3.0`
243243

244+
### XTC Sampling
245+
246+
- `--xtc-probability N`: Sets the chance for token removal (checked once on sampler start) (default: 0.0).
247+
- `--xtc-threshold N`: Sets a minimum probability threshold for tokens to be removed (default: 0.1).
248+
249+
Exclude Top Choices (XTC) is a unique sampler that is designed to remove top tokens from consideration and avoid more obvious and repetitive outputs. With a chance of `xtc-probability` it searches for tokens with probabilities of `xtc-threshold` and above, then removes all such tokens except the least probable one.
250+
251+
By removing top tokens XTC can improve the variety of answers, break writing clichés and inhibit repition, since clichés and repeated phrases are usually more likely to appear. By keeping the last token above the threshold, XTC ensures that the answer is still coherent. XTC is meant to be used for creative tasks, but feel free to experiment with different settings for different models.
252+
253+
Being experimental and unique, XTC is disabled by default. The recommended combination of samplers is Min-P followed by XTC on its default settings: `--sampling-seq mx --min-p 0.02 --xtc-probability 0.5`.
254+
255+
Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1`
256+
244257
### Logit Bias
245258

246259
- `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion.

examples/server/public/index-new.html

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
top_k: 0, // <= 0 to use vocab size
4444
top_p: 1.0, // 1.0 = disabled
4545
min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
46+
xtc_probability: 0.0, // 0 = disabled;
47+
xtc_threshold: 0.1, // > 0.5 disables XTC;
4648
tfs_z: 1.0, // 1.0 = disabled
4749
typical_p: 1.0, // 1.0 = disabled
4850
presence_penalty: 0.0, // 0.0 = disabled
@@ -836,6 +838,8 @@
836838
${FloatField({ label: "TFS-Z", title: "Activates tail-free sampling, a method used to limit the prediction of tokens that are too frequent. The parameter z controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
837839
${FloatField({ label: "Frequency Penalty", title: "A penalty that is applied based on the frequency with which certain tokens occur in the training data set. A higher value results in rare tokens being favoured.", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
838840
${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
841+
${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
842+
${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
839843
${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
840844
</fieldset>
841845
@@ -1132,6 +1136,8 @@ <h2>llama.cpp</h2>
11321136
const snapSettings = {
11331137
temperature: { snapValue: 1.0, snapRangeMultiplier: 6 },
11341138
min_p: { snapValue: 0.05, snapRangeMultiplier: 2 },
1139+
xtc_probability: { snapValue: 0.0, snapRangeMultiplier: 4 },
1140+
xtc_threshold: { snapValue: 0.5, snapRangeMultiplier: 4 },
11351141
top_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
11361142
tfs_z: { snapValue: 1.0, snapRangeMultiplier: 4 },
11371143
typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 },

examples/server/public/index.html

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,8 @@
307307
top_k: 40, // <= 0 to use vocab size
308308
top_p: 0.95, // 1.0 = disabled
309309
min_p: 0.05, // 0 = disabled
310+
xtc_probability: 0.0, // 0 = disabled;
311+
xtc_threshold: 0.1, // > 0.5 disables XTC;
310312
tfs_z: 1.0, // 1.0 = disabled
311313
typical_p: 1.0, // 1.0 = disabled
312314
presence_penalty: 0.0, // 0.0 = disabled
@@ -1013,6 +1015,8 @@
10131015
${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
10141016
${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
10151017
${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
1018+
${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
1019+
${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
10161020
</fieldset>
10171021
<hr />
10181022
<fieldset class="three">

examples/server/server.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,8 @@ struct server_context {
863863
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
864864
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
865865
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
866+
slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
867+
slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
866868
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
867869
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
868870
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
@@ -1196,6 +1198,8 @@ struct server_context {
11961198
{"top_k", slot.sparams.top_k},
11971199
{"top_p", slot.sparams.top_p},
11981200
{"min_p", slot.sparams.min_p},
1201+
{"xtc_probability", slot.sparams.xtc_probability},
1202+
{"xtc_threshold", slot.sparams.xtc_threshold},
11991203
{"tfs_z", slot.sparams.tfs_z},
12001204
{"typical_p", slot.sparams.typ_p},
12011205
{"repeat_last_n", slot.sparams.penalty_last_n},

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,9 @@ extern "C" {
11011101
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
11021102
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
11031103

1104+
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
1105+
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
1106+
11041107
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
11051108
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
11061109
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.

src/llama-sampling.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,101 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
10591059
};
10601060
}
10611061

1062+
// xtc
1063+
1064+
struct llama_sampler_xtc {
1065+
const float probability;
1066+
const float threshold;
1067+
const size_t min_keep;
1068+
1069+
const uint32_t seed;
1070+
uint32_t seed_cur;
1071+
1072+
std::mt19937 rng;
1073+
};
1074+
1075+
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
1076+
return "xtc";
1077+
}
1078+
1079+
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1080+
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1081+
1082+
if (ctx->probability <= 0.0f
1083+
|| ctx->threshold > 0.5f
1084+
|| cur_p->size < 2) {
1085+
return;
1086+
}
1087+
1088+
std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
1089+
float chance = distribution(ctx->rng);
1090+
if (chance > ctx->probability) return;
1091+
1092+
// in case it's not sorted/recalculated yet
1093+
llama_sampler_softmax_impl(cur_p);
1094+
1095+
int pos_last = 0;
1096+
1097+
for (size_t i = 0; i < cur_p->size; ++i) {
1098+
if (cur_p->data[i].p >= ctx->threshold) {
1099+
pos_last = i;
1100+
} else break;
1101+
}
1102+
1103+
if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
1104+
cur_p->data += pos_last;
1105+
cur_p->size -= pos_last;
1106+
}
1107+
}
1108+
1109+
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
1110+
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
1111+
auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
1112+
1113+
// copy the state
1114+
{
1115+
auto * result_ctx = (llama_sampler_xtc *) result->ctx;
1116+
1117+
result_ctx->rng = ctx->rng;
1118+
}
1119+
1120+
return result;
1121+
}
1122+
1123+
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
1124+
delete (llama_sampler_xtc *) smpl->ctx;
1125+
}
1126+
1127+
static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
1128+
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1129+
ctx->seed_cur = get_rng_seed(ctx->seed);
1130+
ctx->rng.seed(ctx->seed_cur);
1131+
}
1132+
1133+
static struct llama_sampler_i llama_sampler_xtc_i = {
1134+
/* .name = */ llama_sampler_xtc_name,
1135+
/* .accept = */ nullptr,
1136+
/* .apply = */ llama_sample_xtc_apply,
1137+
/* .reset = */ llama_sampler_xtc_reset,
1138+
/* .clone = */ llama_sampler_xtc_clone,
1139+
/* .free = */ llama_sampler_xtc_free,
1140+
};
1141+
1142+
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
1143+
auto seed_cur = get_rng_seed(seed);
1144+
return new llama_sampler {
1145+
/* .iface = */ &llama_sampler_xtc_i,
1146+
/* .ctx = */ new llama_sampler_xtc {
1147+
/* .probability = */ p,
1148+
/* .threshold = */ t,
1149+
/* .min_keep = */ min_keep,
1150+
/* .seed = */ seed,
1151+
/* .seed_cur = */ seed_cur,
1152+
/* .rng = */ std::mt19937(seed_cur),
1153+
},
1154+
};
1155+
}
1156+
10621157
// mirostat
10631158

10641159
struct llama_sampler_mirostat {

0 commit comments

Comments
 (0)