Skip to content

Commit b8623fc

Browse files
committed
Initial implementation of a sequence repetition penalty
1 parent b19edd5 commit b8623fc

File tree

5 files changed

+131
-2
lines changed

5 files changed

+131
-2
lines changed

examples/common.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,36 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
250250
break;
251251
}
252252
params.presence_penalty = std::stof(argv[i]);
253+
} else if (arg == "--seqrep-last-n") {
254+
if (++i >= argc) {
255+
invalid_param = true;
256+
break;
257+
}
258+
params.seqrep_last_n = std::stoi(argv[i]);
259+
} else if (arg == "--seqrep-min-len") {
260+
if (++i >= argc) {
261+
invalid_param = true;
262+
break;
263+
}
264+
params.seqrep_min_len = std::stoi(argv[i]);
265+
} else if (arg == "--seqrep-tolerance") {
266+
if (++i >= argc) {
267+
invalid_param = true;
268+
break;
269+
}
270+
params.seqrep_tolerance = std::stoi(argv[i]);
271+
} else if (arg == "--seqrep-ppenalty") {
272+
if (++i >= argc) {
273+
invalid_param = true;
274+
break;
275+
}
276+
params.seqrep_ppenalty = std::stof(argv[i]);
277+
} else if (arg == "--seqrep-lpenalty") {
278+
if (++i >= argc) {
279+
invalid_param = true;
280+
break;
281+
}
282+
params.seqrep_lpenalty = std::stof(argv[i]);
253283
} else if (arg == "--mirostat") {
254284
if (++i >= argc) {
255285
invalid_param = true;
@@ -556,6 +586,11 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
556586
fprintf(stdout, " --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
557587
fprintf(stdout, " --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
558588
fprintf(stdout, " --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
589+
fprintf(stdout, " --seqrep-last-n N last n tokens to consider for sequence penalizing (default: %d, 0 = disabled, -1 = ctx_size)\n", params.seqrep_last_n);
590+
fprintf(stdout, " --seqrep-min-len N minimum matching sequence length (default: %d, < 2 = disabled)\n", params.seqrep_min_len);
591+
fprintf(stdout, " --seqrep-tolerance N tolerance for fuzzy matching sequences (default: %d, 0 = disabled)\n", params.seqrep_tolerance);
592+
fprintf(stdout, " --seqrep-ppenalty N presence penalty for tokens that can continue a sequence (default: %f, 0.0 = disabled)\n", params.seqrep_ppenalty);
593+
fprintf(stdout, " --seqrep-lpenalty N penalty for tokens that can continue a sequence, multiplied by length (default: %f, 0.0 = disabled)\n", params.seqrep_lpenalty);
559594
fprintf(stdout, " --mirostat N use Mirostat sampling.\n");
560595
fprintf(stdout, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
561596
fprintf(stdout, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);

examples/common.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ struct gpt_params {
4444
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
4545
float frequency_penalty = 0.00f; // 0.0 = disabled
4646
float presence_penalty = 0.00f; // 0.0 = disabled
47+
int32_t seqrep_last_n = 256; // last n tokens to consider for sequence matching (0 = disable penalty, -1 = context size)
48+
int32_t seqrep_min_len = 0; // minimum sequence length to match (< 2 is disabled)
49+
int32_t seqrep_tolerance = 0; // tolerance for fuzzy sequence matching (0 = disabled)
50+
float seqrep_ppenalty = 0.0f; // flat penalty (0.0 = disabled)
51+
float seqrep_lpenalty = 0.0f; // stacking penalty based on length (0.0 = disabled)
4752
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
4853
float mirostat_tau = 5.00f; // target entropy
4954
float mirostat_eta = 0.10f; // learning rate

examples/main/main.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,10 @@ int main(int argc, char ** argv) {
334334
fprintf(stderr, "Input suffix: '%s'\n", params.input_suffix.c_str());
335335
}
336336
}
337-
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
338-
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
337+
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, seqrep(last_n = %d, min_len = %d, tolerance = %d, ppenalty = %f, lpenalty = %f), top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
338+
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty,
339+
params.seqrep_last_n, params.seqrep_min_len, params.seqrep_tolerance, params.seqrep_ppenalty, params.seqrep_lpenalty,
340+
params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
339341
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
340342
fprintf(stderr, "\n\n");
341343

@@ -552,6 +554,7 @@ int main(int argc, char ** argv) {
552554
const float typical_p = params.typical_p;
553555
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
554556
const float repeat_penalty = params.repeat_penalty;
557+
const int32_t seqrep_last_n = params.seqrep_last_n < 0 ? n_ctx : params.seqrep_last_n;
555558
const float alpha_presence = params.presence_penalty;
556559
const float alpha_frequency = params.frequency_penalty;
557560
const int mirostat = params.mirostat;
@@ -597,6 +600,11 @@ int main(int argc, char ** argv) {
597600
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
598601
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
599602
last_n_repeat, alpha_frequency, alpha_presence);
603+
auto seqrep_last_n_repeat = std::min(std::min((int)last_n_tokens.size(), seqrep_last_n), n_ctx);
604+
llama_sample_seqrep_penalty(ctx, &candidates_p,
605+
last_n_tokens.data() + last_n_tokens.size() - seqrep_last_n_repeat,
606+
seqrep_last_n_repeat, params.seqrep_min_len, params.seqrep_tolerance,
607+
params.seqrep_ppenalty, params.seqrep_lpenalty);
600608
if (!penalize_nl) {
601609
logits[llama_token_nl()] = nl_logit;
602610
}

llama.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2651,6 +2651,84 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
26512651
}
26522652
}
26532653

2654+
// Internal helper function for sequence matching.
// Searches backward from `offset` for a sequence of tokens that matches the
// tail of `last_tokens_p` (the run ending at the final token), allowing up to
// `tolerance` fuzzy-match "charges". Returns the matched length, or 0 when no
// match of at least `min_length` is possible from this starting point.
//
// Parameters:
//   last_tokens_p    - token history buffer
//   last_tokens_size - number of tokens in the buffer
//   offset           - index of the last token of the candidate (earlier) sequence
//   min_length       - minimum sequence length worth matching (< 2 disables)
//   tolerance        - budget of non-exact steps allowed during matching
static size_t llama_seqrep_find_match(const llama_token * last_tokens_p, const size_t last_tokens_size, int offset, const size_t min_length, int tolerance) {
    // Reject degenerate inputs: disabled min_length, a history shorter than the
    // minimum, or a start offset that cannot hold min_length tokens before it.
    // (A negative offset becomes huge after the size_t cast and is rejected too.)
    if (min_length < 2 || last_tokens_size < min_length || (size_t)offset < min_length - 1) {
        return 0;
    }

    // tail_offset walks backward from the last token; the candidate sequence at
    // `offset` must end strictly before the tail to be a distinct occurrence.
    int tail_offset = last_tokens_size - 1;
    if (offset >= tail_offset) {
        return 0;
    }
    int matches = 0, wildcard_matches = 0;
    while (offset >= 0) {
        if (last_tokens_p[offset] == last_tokens_p[tail_offset]) {
            // Exact match: step both cursors back. Any pending wildcard steps
            // are confirmed (counted) only now that a real match followed them.
            offset--;
            tail_offset--;
            matches += 1 + wildcard_matches;
            wildcard_matches = 0;
            continue;
        }
        // Mismatch: stop if the fuzz budget is exhausted.
        if (tolerance < 1 || (offset == 0 && tail_offset == 0)) {
            break;
        }
        tolerance--;
        // Spend one tolerance charge on the cheapest way to realign:
        // skip a token on the candidate side, skip one on the tail side,
        // or treat this position as a wildcard and advance both.
        if (offset > 0 && last_tokens_p[offset - 1] == last_tokens_p[tail_offset]) {
            offset--;
        } else if (tail_offset > offset + 1 && last_tokens_p[offset] == last_tokens_p[tail_offset - 1]) {
            tail_offset--;
        } else {
            // A tolerance charge can count as a match, but only if we can find a
            // real match before the search is terminated.
            wildcard_matches++;
            offset--;
            tail_offset--;
        }
    }
    return matches;
}
2692+
2693+
void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, size_t min_length, size_t tolerance, float flat_penalty, float length_penalty) {
2694+
if (min_length < 2 || last_tokens_size <= min_length ||
2695+
(flat_penalty == 0.0f && length_penalty == 0.0f)) {
2696+
return;
2697+
}
2698+
2699+
const int64_t t_start_sample_us = ggml_time_us();
2700+
2701+
// This will hold a map of token ids that can continue the sequence with its max seen sequence length.
2702+
std::unordered_map<llama_token, size_t> penalize_tokens;
2703+
2704+
for (size_t offset = last_tokens_size - 2; offset >= min_length - 1; offset--) {
2705+
const size_t matched_length =
2706+
llama_seqrep_find_match(last_tokens_p, last_tokens_size, offset, min_length, tolerance);
2707+
if (matched_length < min_length) {
2708+
continue;
2709+
}
2710+
2711+
// The token one past where we started trying to match is the one that could continue
2712+
// the previously observed sequence.
2713+
llama_token penalize_token = last_tokens_p[offset + 1];
2714+
2715+
auto pt_iter = penalize_tokens.find(penalize_token);
2716+
if (pt_iter == penalize_tokens.end()) {
2717+
penalize_tokens[penalize_token] = matched_length;
2718+
} else {
2719+
penalize_tokens[penalize_token] = pt_iter->second + matched_length;
2720+
}
2721+
}
2722+
for (const auto it : penalize_tokens) {
2723+
candidates->data[it.first].logit -=
2724+
float(it.second) * length_penalty + float(it.second > 0) * flat_penalty;
2725+
}
2726+
2727+
if (ctx) {
2728+
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
2729+
}
2730+
}
2731+
26542732
void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
26552733
assert(ctx);
26562734
const int64_t t_start_sample_us = ggml_time_us();

llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,9 @@ extern "C" {
407407
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
408408
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
409409

410+
/// @details Sequence repetition penalty: reduces the logits of tokens that would continue a token sequence already present in last_tokens_p. min_length sets the shortest sequence considered (< 2 disables), tolerance allows fuzzy matches, flat_penalty applies once per continuing token, and length_penalty scales with the total matched length.
411+
LLAMA_API void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, size_t min_length, size_t tolerance, float flat_penalty, float length_penalty);
412+
410413
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
411414
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
412415
/// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.

0 commit comments

Comments
 (0)