Skip to content

Commit 17eaa96

Browse files
committed
Major refactoring and expanded features
1 parent 9f4458d commit 17eaa96

File tree

5 files changed

+388
-97
lines changed

5 files changed

+388
-97
lines changed

examples/common.cpp

Lines changed: 154 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,147 @@ void process_escapes(std::string& input) {
9191
input.resize(output_idx);
9292
}
9393

94+
void seqrep_sampler_params_init(llama_sampler_seqrep_params * params) {
95+
assert(params != NULL);
96+
memset(params, 0, sizeof(llama_sampler_seqrep_params));
97+
params->last_n = 256;
98+
params->mid_word_scale = 0.1f;
99+
params->tolerance_half_step_cost = 1.0f;
100+
}
101+
102+
void seqrep_sampler_params_dump(FILE * fp, llama_sampler_seqrep_params * params) {
103+
if (fp == NULL) {
104+
return;
105+
}
106+
assert(params != NULL);
107+
fprintf(fp, "seqrep(last_n = %d, min_length = %zd, start_offset = %zd, presence_penalty = %.4f, length_penalty = %.4f, tolerance = %.4f, mid_word_scale = %.4f, tolerance_match_credit = %.4f, tolerance_half_step_cost = %.4f, flags = %d)",
108+
params->last_n, params->min_length, params->start_offset, params->presence_penalty,
109+
params->length_penalty, params->tolerance, params->mid_word_scale, params->tolerance_match_credit,
110+
params->tolerance_half_step_cost, params->flags);
111+
}
112+
113+
void seqrep_sampler_help() {
114+
llama_sampler_seqrep_params p;
115+
seqrep_sampler_params_init(&p);
116+
fprintf(stderr, "==== Sequence Repetition Sampler Help ====\n\n");
117+
fprintf(stderr, " The sequence repetition sampler takes a configuration string in the format:\n");
118+
fprintf(stderr, " arg1:arg2:argN\n");
119+
fprintf(stderr, " A colon separated argument can be a key value pair like xyz=1 or flag like xyz\n");
120+
fprintf(stderr, "\n- Available key/value arguments\n");
121+
fprintf(stderr, " * repetition_mode=REPEAT_PENALTY\n emulates the repetition penalty sampler. warning: 1.0 disables penalties since this preset enables flag_divide_by_penalty. using 0.0 is probably not what you want\n");
122+
fprintf(stderr, " * presence_mode=PRESENCE_PENALTY\n emulates the presence penalty sampler\n");
123+
fprintf(stderr, " * frequency_mode=FREQUENCY_PENALTY\n Emulates the repetition penalty sampler\n");
124+
fprintf(stderr, " * last_n\n last n tokens to consider for sequence penalizing (default: %d, 0 = disabled, -1 = ctx_size)\n", p.last_n);
125+
fprintf(stderr, " * min_length\n minimum matching sequence length (default: %zd, < 2 = disabled)\n", p.min_length);
126+
fprintf(stderr, " * presence_penalty\n presence penalty for tokens that can continue a sequence (default: %f, 0.0 = disabled)\n", p.presence_penalty);
127+
fprintf(stderr, " * length_penalty\n penalty for tokens that can continue a sequence, multiplied by length (default: %f, 0.0 = disabled)\n", p.length_penalty);
128+
fprintf(stderr, " * tolerance\n tolerance for fuzzy matching sequences (default: %f, 0 = disabled)\n", p.tolerance);
129+
fprintf(stderr, " * mid_word_scale\n scale penalty when for mid-word tokens. 1.0 would mean apply the full penalty (default: %f, 1.0 = disabled)\n", p.mid_word_scale);
130+
fprintf(stderr, " * tolerance_match_credit\n credit tolerance on matched tokens (default: %f, 0.0 = disabled)\n", p.tolerance_match_credit);
131+
fprintf(stderr, " * tolerance_half_step_cost\n advanced option to adjust tolerance cost for failed matches within a half step of a match (default: %f, 1.0 = normal)\n", p.tolerance_half_step_cost);
132+
fprintf(stderr, "\n- Available flags arguments (currently all default to disabled)\n");
133+
fprintf(stderr, " * flag_immediate_wildcard\n when tolerance is consumed, by default it doesn't count as a match until a real match is found\n");
134+
fprintf(stderr, " * flag_tolerance_no_consecutive\n do not allow using tolerance consecutively\n");
135+
fprintf(stderr, " * flag_tolerance_no_first\n do not allow using tolerance before the first match\n");
136+
fprintf(stderr, " * flag_tolerance_cap_initial\n only meaningful with match credit, prevents match credit adjusting tolerance higher than the initial value\n");
137+
fprintf(stderr, " * flag_penalize_length_max_seen\n when applying length_penalty, use the maximum seen sequence length rather than the total length of seen sequences\n");
138+
fprintf(stderr, " * flag_divide_by_penalty\n divide the logit by when applying a penalty rather than subtracting it. warning: when this flag is enabled, 1.0 disables penalties not 0.0. 0.0 is probably not what you want\n");
139+
fprintf(stderr, "\n- Examples:\n");
140+
fprintf(stderr, " * repetition_mode=1.2:last_n=32\n same as --repeat-last-n 32 --repeat-penalty 1.2\n");
141+
fprintf(stderr, " * presence_mode=.2:last_n=32\n same as --repeat-last-n 32 --presence-penalty .2\n");
142+
fprintf(stderr, " * frequency_mode=.2:last_n=32\n same as --repeat-last-n 32 --frequency-penalty .2\n");
143+
fprintf(stderr, " * min_length=3:tolerance=1:length_penalty=.2:last_n=-1\n match repeated sequences of at least 3 tokens within the entire context and apply a penalty of 0.2*total_length to the token that would continue the sequence. allow one non-matching token in matched sequences.\n");
144+
}
145+
146+
bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params) {
147+
assert(params != NULL);
148+
assert(s != NULL);
149+
size_t offset = 0;
150+
std::string sparams = s;
151+
size_t slen = sparams.size();
152+
153+
while (offset < slen) {
154+
// printf("SR OFFS: %lu\n", offset);
155+
size_t argsep = sparams.find_first_of(':', offset);
156+
std::string argchunk;
157+
if (argsep == std::string::npos) {
158+
argchunk = sparams.substr(offset);
159+
} else if (argsep > offset) {
160+
argchunk = sparams.substr(offset, argsep - offset);
161+
}
162+
std::string argval;
163+
size_t valsep = argchunk.find_first_of('=');
164+
if (valsep != std::string::npos && valsep < argchunk.size()) {
165+
argval = argchunk.substr(valsep + 1);
166+
argchunk.resize(valsep);
167+
}
168+
// printf("SR: k[%s] = v[%s]\n", argchunk.c_str(), argval.c_str());
169+
if (argchunk.empty() && argval.empty()) {
170+
// pass
171+
} else if (argchunk == "repetition_mode") {
172+
params->last_n = 64;
173+
params->min_length = 1;
174+
params->mid_word_scale = 1.0f;
175+
params->flags = LLAMA_SEQREP_DIVIDE_BY_PENALTY;
176+
params->length_penalty = 1.0f;
177+
params->presence_penalty = argval.empty() ? 1.1f : std::atof(argval.c_str());
178+
} else if (argchunk == "presence_mode") {
179+
params->last_n = 64;
180+
params->min_length = 1;
181+
params->mid_word_scale = 1.0f;
182+
params->flags = 0;
183+
params->length_penalty = 0.0f;
184+
params->presence_penalty = std::atof(argval.c_str());
185+
} else if (argchunk == "frequency_mode") {
186+
params->last_n = 64;
187+
params->min_length = 1;
188+
params->mid_word_scale = 1.0f;
189+
params->flags = 0;
190+
params->length_penalty = std::atof(argval.c_str());
191+
params->presence_penalty = 0.0f;
192+
} else if (argchunk == "flag_immediate_wildcard") {
193+
params->flags |= LLAMA_SEQREP_IMMEDIATE_WILDCARD;
194+
} else if (argchunk == "flag_tolerance_no_consecutive") {
195+
params->flags |= LLAMA_SEQREP_TOLERANCE_NO_CONSECUTIVE;
196+
} else if (argchunk == "flag_tolerance_no_first") {
197+
params->flags |= LLAMA_SEQREP_TOLERANCE_NO_FIRST;
198+
} else if (argchunk == "flag_tolerance_cap_initial") {
199+
params->flags |= LLAMA_SEQREP_TOLERANCE_CAP_INITIAL;
200+
} else if (argchunk == "flag_penalize_length_max_seen") {
201+
params->flags |= LLAMA_SEQREP_PENALIZE_LENGTH_MAX_SEEN;
202+
} else if (argchunk == "flag_divide_by_penalty") {
203+
params->flags |= LLAMA_SEQREP_DIVIDE_BY_PENALTY;
204+
} else if (argchunk == "min_length") {
205+
params->min_length = std::atoi(argval.c_str());
206+
} else if (argchunk == "start_offset") {
207+
params->start_offset = std::atoi(argval.c_str());
208+
} else if (argchunk == "last_n") {
209+
params->last_n = std::atoi(argval.c_str());
210+
} else if (argchunk == "tolerance") {
211+
params->tolerance = std::atof(argval.c_str());
212+
} else if (argchunk == "presence_penalty") {
213+
params->presence_penalty = std::atof(argval.c_str());
214+
} else if (argchunk == "length_penalty") {
215+
params->length_penalty = std::atof(argval.c_str());
216+
} else if (argchunk == "mid_word_scale") {
217+
params->mid_word_scale = std::atof(argval.c_str());
218+
} else if (argchunk == "tolerance_match_credit") {
219+
params->tolerance_match_credit = std::atof(argval.c_str());
220+
} else if (argchunk == "tolerance_half_step_cost") {
221+
params->tolerance_half_step_cost = std::atof(argval.c_str());
222+
} else {
223+
fprintf(stderr, "seqrep: Bad argument [%s]=[%s]!\n", argchunk.c_str(), argval.c_str());
224+
return false;
225+
}
226+
if (argsep != std::string::npos) {
227+
offset = argsep + 1;
228+
} else {
229+
break;
230+
}
231+
}
232+
return true;
233+
}
234+
94235
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
95236
bool invalid_param = false;
96237
bool escape_prompt = false;
@@ -250,42 +391,25 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
250391
break;
251392
}
252393
params.presence_penalty = std::stof(argv[i]);
253-
} else if (arg == "--seqrep-last-n") {
394+
} else if (arg == "-seqrep" || arg == "--seqrep-penalty") {
254395
if (++i >= argc) {
255396
invalid_param = true;
256397
break;
257398
}
258-
params.seqrep_last_n = std::stoi(argv[i]);
259-
} else if (arg == "--seqrep-min-len") {
260-
if (++i >= argc) {
261-
invalid_param = true;
262-
break;
399+
if (strcasecmp(argv[i], "help") == 0) {
400+
seqrep_sampler_help();
401+
exit(0);
263402
}
264-
params.seqrep_min_len = std::stoi(argv[i]);
265-
} else if (arg == "--seqrep-tolerance") {
266-
if (++i >= argc) {
267-
invalid_param = true;
403+
llama_sampler_seqrep_params sr_params;
404+
seqrep_sampler_params_init(&sr_params);
405+
invalid_param = !seqrep_sampler_params_parse(argv[i], &sr_params);
406+
if (invalid_param) {
268407
break;
269408
}
270-
params.seqrep_tolerance = std::stoi(argv[i]);
271-
} else if (arg == "--seqrep-ppenalty") {
272-
if (++i >= argc) {
273-
invalid_param = true;
274-
break;
275-
}
276-
params.seqrep_ppenalty = std::stof(argv[i]);
277-
} else if (arg == "--seqrep-lpenalty") {
278-
if (++i >= argc) {
279-
invalid_param = true;
280-
break;
281-
}
282-
params.seqrep_lpenalty = std::stof(argv[i]);
283-
} else if (arg == "--seqrep-mw-scale") {
284-
if (++i >= argc) {
285-
invalid_param = true;
286-
break;
409+
if (sr_params.last_n != 0 && sr_params.min_length > 0
410+
&& (sr_params.presence_penalty != 0.0f || sr_params.length_penalty != 0.0f)) {
411+
params.seqrep_params.push_back(sr_params);
287412
}
288-
params.seqrep_mw_scale = std::stof(argv[i]);
289413
} else if (arg == "--mirostat") {
290414
if (++i >= argc) {
291415
invalid_param = true;
@@ -592,12 +716,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
592716
fprintf(stdout, " --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
593717
fprintf(stdout, " --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
594718
fprintf(stdout, " --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
595-
fprintf(stdout, " --seqrep-last-n N last n tokens to consider for sequence penalizing (default: %d, 0 = disabled, -1 = ctx_size)\n", params.seqrep_last_n);
596-
fprintf(stdout, " --seqrep-min-len N minimum matching sequence length (default: %d, < 2 = disabled)\n", params.seqrep_min_len);
597-
fprintf(stdout, " --seqrep-tolerance N tolerance for fuzzy matching sequences (default: %d, 0 = disabled)\n", params.seqrep_tolerance);
598-
fprintf(stdout, " --seqrep-ppenalty N presence penalty for tokens that can continue a sequence (default: %f, 0.0 = disabled)\n", params.seqrep_ppenalty);
599-
fprintf(stdout, " --seqrep-lpenalty N penalty for tokens that can continue a sequence, multiplied by length (default: %f, 0.0 = disabled)\n", params.seqrep_lpenalty);
600-
fprintf(stdout, " --seqrep-mw-scale N scale penalty when for mid-word tokens. 1.0 would mean apply the full penalty (default: %f, 1.0 = disabled)\n", params.seqrep_mw_scale);
719+
fprintf(stdout, " -seqrep CFG, --seqrep-penalty CFG\n");
720+
fprintf(stdout, " add a copy of the sequence repetition penalty sampler. may be specified multiple times. for help: -seqrep help\n");
601721
fprintf(stdout, " --mirostat N use Mirostat sampling.\n");
602722
fprintf(stdout, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
603723
fprintf(stdout, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);

examples/common.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,7 @@ struct gpt_params {
4444
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
4545
float frequency_penalty = 0.00f; // 0.0 = disabled
4646
float presence_penalty = 0.00f; // 0.0 = disabled
47-
int32_t seqrep_last_n = 256; // last n tokens to penalize (0 = disable penalty, -1 = context size)
48-
int32_t seqrep_min_len = 0; // minimum sequence length to match (< 2 is disabled)
49-
int32_t seqrep_tolerance = 0; // tolerance for fuzzy sequence matching (0 = disabled)
50-
float seqrep_ppenalty = 0.0f; // flat penalty (0.0 = disabled)
51-
float seqrep_lpenalty = 0.0f; // stacking penalty based on length (0.0 = disabled)
52-
float seqrep_mw_scale = 0.1f; // scale penalty when applied to mid-word tokens (1.0 = apply full penalty)
47+
std::vector<llama_sampler_seqrep_params> seqrep_params;
5348
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
5449
float mirostat_tau = 5.00f; // target entropy
5550
float mirostat_eta = 0.10f; // learning rate
@@ -118,3 +113,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
118113

119114
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
120115
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
116+
117+
void seqrep_sampler_params_init(llama_sampler_seqrep_params * params);
118+
void seqrep_sampler_params_dump(FILE * fp, llama_sampler_seqrep_params * params);
119+
bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params);

examples/main/main.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -334,10 +334,15 @@ int main(int argc, char ** argv) {
334334
fprintf(stderr, "Input suffix: '%s'\n", params.input_suffix.c_str());
335335
}
336336
}
337-
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, seqrep(last_n = %d, min_len = %d, tolerance = %d, ppenalty = %f, lpenalty = %f, mw_scale = %f), top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
337+
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f",
338338
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty,
339-
params.seqrep_last_n, params.seqrep_min_len, params.seqrep_tolerance, params.seqrep_ppenalty, params.seqrep_lpenalty, params.seqrep_mw_scale,
340-
params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
339+
params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp,
340+
params.mirostat, params.mirostat_eta, params.mirostat_tau);
341+
for (auto & sr_params : params.seqrep_params) {
342+
fprintf(stderr, ", ");
343+
seqrep_sampler_params_dump(stderr, &sr_params);
344+
}
345+
fprintf(stderr, "\n");
341346
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
342347
fprintf(stderr, "\n\n");
343348

@@ -554,7 +559,6 @@ int main(int argc, char ** argv) {
554559
const float typical_p = params.typical_p;
555560
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
556561
const float repeat_penalty = params.repeat_penalty;
557-
const int32_t seqrep_last_n = params.seqrep_last_n < 0 ? n_ctx : params.seqrep_last_n;
558562
const float alpha_presence = params.presence_penalty;
559563
const float alpha_frequency = params.frequency_penalty;
560564
const int mirostat = params.mirostat;
@@ -600,11 +604,11 @@ int main(int argc, char ** argv) {
600604
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
601605
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
602606
last_n_repeat, alpha_frequency, alpha_presence);
603-
auto seqrep_last_n_repeat = std::min(std::min((int)last_n_tokens.size(), seqrep_last_n), n_ctx);
604-
llama_sample_seqrep_penalty(ctx, &candidates_p,
605-
last_n_tokens.data() + last_n_tokens.size() - seqrep_last_n_repeat,
606-
seqrep_last_n_repeat, params.seqrep_min_len, params.seqrep_tolerance,
607-
params.seqrep_ppenalty, params.seqrep_lpenalty, params.seqrep_mw_scale);
607+
608+
for (auto & sr_params : params.seqrep_params) {
609+
llama_sample_seqrep_penalty(ctx, &candidates_p, last_n_tokens.data(), last_n_tokens.size(), &sr_params);
610+
}
611+
608612
if (!penalize_nl) {
609613
logits[llama_token_nl()] = nl_logit;
610614
}

0 commit comments

Comments
 (0)