Skip to content

Commit 0875559

Browse files
committed
Initial implementation of a sequence repetition penalty
1 parent 6e91a1b commit 0875559

File tree

5 files changed

+492
-2
lines changed

5 files changed

+492
-2
lines changed

common/common.cpp

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,147 @@ void process_escapes(std::string& input) {
9191
input.resize(output_idx);
9292
}
9393

94+
// Reset *params to the sequence repetition sampler's default configuration.
// Everything is zeroed first so options not listed below default to "disabled",
// then the few fields with non-zero defaults are filled in.
void seqrep_sampler_params_init(llama_sampler_seqrep_params * params) {
    assert(params != NULL);
    memset(params, 0, sizeof(*params));
    // Non-zero defaults.
    params->last_n                   = 256;
    params->mid_word_scale           = 0.1f;
    params->tolerance_half_step_cost = 1.0f;
}
101+
102+
// Write a single-line, human-readable dump of *params to fp.
// Silently does nothing when fp is NULL; asserts on a NULL params.
// NOTE(review): min_length/start_offset are printed with %zd — if those fields
// are size_t (unsigned) rather than ssize_t, %zu would be the correct
// specifier; confirm against the struct declaration.
void seqrep_sampler_params_dump(FILE * fp, llama_sampler_seqrep_params * params) {
    if (fp == NULL) {
        return;
    }
    assert(params != NULL);
    fprintf(fp,
        "seqrep(last_n = %d, min_length = %zd, start_offset = %zd, presence_penalty = %.4f, length_penalty = %.4f, tolerance = %.4f, mid_word_scale = %.4f, tolerance_match_credit = %.4f, tolerance_half_step_cost = %.4f, flags = %d)",
        params->last_n,
        params->min_length,
        params->start_offset,
        params->presence_penalty,
        params->length_penalty,
        params->tolerance,
        params->mid_word_scale,
        params->tolerance_match_credit,
        params->tolerance_half_step_cost,
        params->flags);
}
112+
113+
// Print usage help for the -seqrep / --seqrep-penalty option to stdout.
// A default-initialized parameter struct supplies the default values that are
// substituted into the help text, so the help always matches
// seqrep_sampler_params_init().
void seqrep_sampler_help() {
    llama_sampler_seqrep_params p;

    seqrep_sampler_params_init(&p);
    fprintf(stdout, "==== Sequence Repetition Sampler Help ====\n\n");
    fprintf(stdout, " The sequence repetition sampler takes a configuration string in the format:\n");
    fprintf(stdout, " arg1:arg2:argN\n");
    fprintf(stdout, " A colon separated argument can be a key value pair like xyz=1 or flag like xyz\n");
    fprintf(stdout, "\n- Available key/value arguments\n");
    fprintf(stdout, " * repetition_mode=REPEAT_PENALTY\n emulates the repetition penalty sampler. warning: 1.0 disables penalties since this preset enables flag_divide_by_penalty. using 0.0 is probably not what you want\n");
    fprintf(stdout, " * presence_mode=PRESENCE_PENALTY\n emulates the presence penalty sampler\n");
    // FIX: previously said "Emulates the repetition penalty sampler" — a
    // copy-paste error; this preset emulates the frequency penalty sampler.
    fprintf(stdout, " * frequency_mode=FREQUENCY_PENALTY\n emulates the frequency penalty sampler\n");
    fprintf(stdout, " * last_n\n last n tokens to consider for sequence penalizing (default: %d, 0 = disabled, -1 = ctx_size)\n", p.last_n);
    fprintf(stdout, " * min_length\n minimum matching sequence length (default: %zd, < 2 = disabled)\n", p.min_length);
    fprintf(stdout, " * presence_penalty\n presence penalty for tokens that can continue a sequence (default: %f, 0.0 = disabled)\n", p.presence_penalty);
    fprintf(stdout, " * length_penalty\n penalty for tokens that can continue a sequence, multiplied by length (default: %f, 0.0 = disabled)\n", p.length_penalty);
    fprintf(stdout, " * tolerance\n tolerance for fuzzy matching sequences (default: %f, 0 = disabled)\n", p.tolerance);
    // FIX: dropped the stray "when" from "scale penalty when for mid-word tokens".
    fprintf(stdout, " * mid_word_scale\n scale penalty for mid-word tokens. 1.0 would mean apply the full penalty (default: %f, 1.0 = disabled)\n", p.mid_word_scale);
    fprintf(stdout, " * tolerance_match_credit\n credit tolerance on matched tokens (default: %f, 0.0 = disabled)\n", p.tolerance_match_credit);
    fprintf(stdout, " * tolerance_half_step_cost\n advanced option to adjust tolerance cost for failed matches within a half step of a match (default: %f, 1.0 = normal)\n", p.tolerance_half_step_cost);
    fprintf(stdout, "\n- Available flags arguments (currently all default to disabled)\n");
    fprintf(stdout, " * flag_immediate_wildcard\n when tolerance is consumed, by default it doesn't count as a match until a real match is found\n");
    fprintf(stdout, " * flag_tolerance_no_consecutive\n do not allow using tolerance consecutively\n");
    fprintf(stdout, " * flag_tolerance_no_first\n do not allow using tolerance before the first match\n");
    fprintf(stdout, " * flag_tolerance_cap_initial\n only meaningful with match credit, prevents match credit adjusting tolerance higher than the initial value\n");
    fprintf(stdout, " * flag_penalize_length_max_seen\n when applying length_penalty, use the maximum seen sequence length rather than the total length of seen sequences\n");
    fprintf(stdout, " * flag_divide_by_penalty\n divide the logit when applying a penalty rather than subtracting it. warning: when this flag is enabled, 1.0 disables penalties not 0.0. 0.0 is probably not what you want\n");
    fprintf(stdout, "\n- Examples:\n");
    fprintf(stdout, " * repetition_mode=1.2:last_n=32\n same as --repeat-last-n 32 --repeat-penalty 1.2\n");
    fprintf(stdout, " * presence_mode=.2:last_n=32\n same as --repeat-last-n 32 --presence-penalty .2\n");
    fprintf(stdout, " * frequency_mode=.2:last_n=32\n same as --repeat-last-n 32 --frequency-penalty .2\n");
    fprintf(stdout, " * min_length=3:tolerance=1:length_penalty=.2:last_n=-1\n match repeated sequences of at least 3 tokens within the entire context and apply a penalty of 0.2*total_length to the token that would continue the sequence. allow one non-matching token in matched sequences.\n");
}
145+
146+
bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params) {
147+
assert(params != NULL);
148+
assert(s != NULL);
149+
size_t offset = 0;
150+
std::string sparams = s;
151+
size_t slen = sparams.size();
152+
153+
while (offset < slen) {
154+
// printf("SR OFFS: %lu\n", offset);
155+
size_t argsep = sparams.find_first_of(':', offset);
156+
std::string argchunk;
157+
if (argsep == std::string::npos) {
158+
argchunk = sparams.substr(offset);
159+
} else if (argsep > offset) {
160+
argchunk = sparams.substr(offset, argsep - offset);
161+
}
162+
std::string argval;
163+
size_t valsep = argchunk.find_first_of('=');
164+
if (valsep != std::string::npos && valsep < argchunk.size()) {
165+
argval = argchunk.substr(valsep + 1);
166+
argchunk.resize(valsep);
167+
}
168+
// printf("SR: k[%s] = v[%s]\n", argchunk.c_str(), argval.c_str());
169+
if (argchunk.empty() && argval.empty()) {
170+
// pass
171+
} else if (argchunk == "repetition_mode") {
172+
params->last_n = 64;
173+
params->min_length = 1;
174+
params->mid_word_scale = 1.0f;
175+
params->flags = LLAMA_SEQREP_DIVIDE_BY_PENALTY;
176+
params->length_penalty = 1.0f;
177+
params->presence_penalty = argval.empty() ? 1.1f : std::atof(argval.c_str());
178+
} else if (argchunk == "presence_mode") {
179+
params->last_n = 64;
180+
params->min_length = 1;
181+
params->mid_word_scale = 1.0f;
182+
params->flags = 0;
183+
params->length_penalty = 0.0f;
184+
params->presence_penalty = std::atof(argval.c_str());
185+
} else if (argchunk == "frequency_mode") {
186+
params->last_n = 64;
187+
params->min_length = 1;
188+
params->mid_word_scale = 1.0f;
189+
params->flags = 0;
190+
params->length_penalty = std::atof(argval.c_str());
191+
params->presence_penalty = 0.0f;
192+
} else if (argchunk == "flag_immediate_wildcard") {
193+
params->flags |= LLAMA_SEQREP_IMMEDIATE_WILDCARD;
194+
} else if (argchunk == "flag_tolerance_no_consecutive") {
195+
params->flags |= LLAMA_SEQREP_TOLERANCE_NO_CONSECUTIVE;
196+
} else if (argchunk == "flag_tolerance_no_first") {
197+
params->flags |= LLAMA_SEQREP_TOLERANCE_NO_FIRST;
198+
} else if (argchunk == "flag_tolerance_cap_initial") {
199+
params->flags |= LLAMA_SEQREP_TOLERANCE_CAP_INITIAL;
200+
} else if (argchunk == "flag_penalize_length_max_seen") {
201+
params->flags |= LLAMA_SEQREP_PENALIZE_LENGTH_MAX_SEEN;
202+
} else if (argchunk == "flag_divide_by_penalty") {
203+
params->flags |= LLAMA_SEQREP_DIVIDE_BY_PENALTY;
204+
} else if (argchunk == "min_length") {
205+
params->min_length = std::atoi(argval.c_str());
206+
} else if (argchunk == "start_offset") {
207+
params->start_offset = std::atoi(argval.c_str());
208+
} else if (argchunk == "last_n") {
209+
params->last_n = std::atoi(argval.c_str());
210+
} else if (argchunk == "tolerance") {
211+
params->tolerance = std::atof(argval.c_str());
212+
} else if (argchunk == "presence_penalty") {
213+
params->presence_penalty = std::atof(argval.c_str());
214+
} else if (argchunk == "length_penalty") {
215+
params->length_penalty = std::atof(argval.c_str());
216+
} else if (argchunk == "mid_word_scale") {
217+
params->mid_word_scale = std::atof(argval.c_str());
218+
} else if (argchunk == "tolerance_match_credit") {
219+
params->tolerance_match_credit = std::atof(argval.c_str());
220+
} else if (argchunk == "tolerance_half_step_cost") {
221+
params->tolerance_half_step_cost = std::atof(argval.c_str());
222+
} else {
223+
fprintf(stderr, "seqrep: Bad argument [%s]=[%s]!\n", argchunk.c_str(), argval.c_str());
224+
return false;
225+
}
226+
if (argsep != std::string::npos) {
227+
offset = argsep + 1;
228+
} else {
229+
break;
230+
}
231+
}
232+
return true;
233+
}
234+
94235
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
95236
bool invalid_param = false;
96237
bool escape_prompt = false;
@@ -238,6 +379,25 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
238379
break;
239380
}
240381
params.presence_penalty = std::stof(argv[i]);
382+
} else if (arg == "-seqrep" || arg == "--seqrep-penalty") {
383+
if (++i >= argc) {
384+
invalid_param = true;
385+
break;
386+
}
387+
if (std::strcmp(argv[i], "help") == 0) {
388+
seqrep_sampler_help();
389+
exit(0);
390+
}
391+
llama_sampler_seqrep_params sr_params;
392+
seqrep_sampler_params_init(&sr_params);
393+
if (!seqrep_sampler_params_parse(argv[i], &sr_params)) {
394+
seqrep_sampler_help();
395+
exit(1);
396+
}
397+
if (sr_params.last_n != 0 && sr_params.min_length > 0
398+
&& (sr_params.presence_penalty != 0.0f || sr_params.length_penalty != 0.0f)) {
399+
params.seqrep_params.push_back(sr_params);
400+
}
241401
} else if (arg == "--mirostat") {
242402
if (++i >= argc) {
243403
invalid_param = true;
@@ -568,6 +728,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
568728
fprintf(stdout, " --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
569729
fprintf(stdout, " --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
570730
fprintf(stdout, " --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
731+
fprintf(stdout, " -seqrep CFG, --seqrep-penalty CFG\n");
732+
fprintf(stdout, " add a copy of the sequence repetition penalty sampler. may be specified multiple times. for help: -seqrep help\n");
571733
fprintf(stdout, " --mirostat N use Mirostat sampling.\n");
572734
fprintf(stdout, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
573735
fprintf(stdout, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);

common/common.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ struct gpt_params {
4141
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
4242
float frequency_penalty = 0.00f; // 0.0 = disabled
4343
float presence_penalty = 0.00f; // 0.0 = disabled
44+
std::vector<llama_sampler_seqrep_params> seqrep_params;
4445
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
4546
float mirostat_tau = 5.00f; // target entropy
4647
float mirostat_eta = 0.10f; // learning rate
@@ -123,3 +124,7 @@ std::vector<llama_token> llama_tokenize(
123124
std::string llama_token_to_str(
124125
const struct llama_context * ctx,
125126
llama_token token);
127+
128+
// Initialize *params to the sequence repetition sampler defaults.
void seqrep_sampler_params_init(llama_sampler_seqrep_params * params);
// Write a one-line human-readable dump of *params to fp (no-op when fp is NULL).
void seqrep_sampler_params_dump(FILE * fp, llama_sampler_seqrep_params * params);
// Parse a colon-separated "key=value:flag" config string into *params.
// Returns false (and prints to stderr) on an unrecognized argument.
bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params);

examples/main/main.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,15 @@ int main(int argc, char ** argv) {
334334
fprintf(stderr, "Input suffix: '%s'\n", params.input_suffix.c_str());
335335
}
336336
}
337-
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
338-
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
337+
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f",
338+
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty,
339+
params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp,
340+
params.mirostat, params.mirostat_eta, params.mirostat_tau);
341+
for (auto & sr_params : params.seqrep_params) {
342+
fprintf(stderr, ", ");
343+
seqrep_sampler_params_dump(stderr, &sr_params);
344+
}
345+
fprintf(stderr, "\n");
339346
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
340347
fprintf(stderr, "\n\n");
341348

@@ -596,6 +603,11 @@ int main(int argc, char ** argv) {
596603
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
597604
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
598605
last_n_repeat, alpha_frequency, alpha_presence);
606+
607+
for (auto & sr_params : params.seqrep_params) {
608+
llama_sample_seqrep_penalty(ctx, &candidates_p, last_n_tokens.data(), last_n_tokens.size(), &sr_params);
609+
}
610+
599611
if (!penalize_nl) {
600612
logits[llama_token_nl(ctx)] = nl_logit;
601613
}

0 commit comments

Comments
 (0)