Commit 84ee695

committed
Fix a serious issue with addressing candidates by tokenid
1 parent 46fad6b

File tree

1 file changed: +16 -11 lines changed

llama.cpp

Lines changed: 16 additions & 11 deletions
@@ -2825,27 +2825,32 @@ void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_ar
     const bool ends_on_word = params->mid_word_scale == 1.0f
         || (llama_seqrep_check_word(ctx, last_tokens_p[last_tokens_size - 1]) & 2) != 0;
 
-    for (const auto it : penalize_tokens) {
-        const bool pt_starts_word = params->mid_word_scale == 1.0f ||
-            (llama_seqrep_check_word(ctx, it.first) & 1) != 0;
-        float scale = ends_on_word || pt_starts_word ? 1.0f : params->mid_word_scale;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        auto pt_iter = penalize_tokens.find(candidates->data[i].id);
+        if (pt_iter == penalize_tokens.end()) {
+            continue;
+        }
 
-        float logit = candidates->data[it.first].logit;
+        const size_t count = pt_iter->second;
+        const bool pt_starts_word = params->mid_word_scale == 1.0f ||
+            (llama_seqrep_check_word(ctx, candidates->data[i].id) & 1) != 0;
+        float penalty_scale = ends_on_word || pt_starts_word ? 1.0f : params->mid_word_scale;
+        float logit = candidates->data[i].logit;
 
         if ((flags & LLAMA_SEQREP_DIVIDE_BY_PENALTY) == 0) {
             float penalty =
-                ( float(it.second) * params->length_penalty
-                + float(it.second > 0) * params->presence_penalty );
-            logit -= penalty * scale;
+                ( float(count) * params->length_penalty
+                + float(count > 0) * params->presence_penalty );
+            logit -= penalty * penalty_scale;
         } else {
             // This looks complicated. The point is to be able to scale penalties like
             // 1.2. For example, suppose length penalty is 1.2 and length is 3. 1.2 * 3 = 3.6
             // would be ridiculous. What we actually want is more like 1.6.
             // An alternative approach would be to iteratively apply the scale.
             // 10.0 / 1.6 == 6.25, however ((10.0 / 1.2) / 1.2) / 1.2 == 5.787
             float penalty =
-                ( (float(it.second) * (params->length_penalty - 1.0f))
-                + (float(it.second > 0) * (params->presence_penalty - 1.0f)) ) * scale
+                ( (float(count) * (params->length_penalty - 1.0f))
+                + (float(count > 0) * (params->presence_penalty - 1.0f)) ) * penalty_scale
                 + 1.0f;
             if (logit <= 0) {
                 logit *= penalty;
@@ -2857,7 +2862,7 @@ void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_ar
                 logit = 0.0f;
             }
         }
-        candidates->data[it.first].logit = logit;
+        candidates->data[i].logit = logit;
     }
 
     candidates->sorted = false;
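
To make the bug concrete: before this change, the loop iterated penalize_tokens and used each token id (it.first) as a direct index into candidates->data. That is only correct while the candidate array happens to be ordered so that position equals token id; once the array has been reordered (for example, sorted by logit), the penalty lands on the wrong entry. The fix scans candidates by position and looks each id up in penalize_tokens instead. Below is a minimal standalone sketch of the failure mode and the fixed lookup, using made-up types and values rather than the project's actual llama_token_data API:

#include <cstdio>
#include <unordered_map>
#include <vector>

// Stand-in for a candidate entry: a token id plus its logit.
struct token_data { int id; float logit; };

int main() {
    // Candidates sorted by logit, so position no longer equals token id.
    std::vector<token_data> candidates = { {2, 9.0f}, {0, 5.0f}, {1, 1.0f} };

    // Token 2 repeated 3 times and should be penalized.
    std::unordered_map<int, size_t> penalize_tokens = { {2, 3} };

    // Buggy pattern: the token id is used as an array index, so this
    // touches candidates[2], which currently holds token 1.
    for (const auto & it : penalize_tokens) {
        std::printf("by id:   token %d\n", candidates[it.first].id);
    }

    // Fixed pattern (as in the commit): scan by position and look the
    // id up in the penalty map, so the right slot is always found.
    for (size_t i = 0; i < candidates.size(); ++i) {
        auto pt_iter = penalize_tokens.find(candidates[i].id);
        if (pt_iter == penalize_tokens.end()) {
            continue;
        }
        std::printf("by scan: token %d\n", candidates[i].id);
    }
    return 0;
}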
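
Separately, the arithmetic in the LLAMA_SEQREP_DIVIDE_BY_PENALTY comment can be checked in isolation: with a length penalty of 1.2 and a repeat length of 3, naive multiplication gives 3.6, while the formula scales only the fractional part, giving 3 * 0.2 + 1 = 1.6, so a logit of 10.0 divided by it yields 6.25. A small sketch with the comment's example numbers (values are illustrative, and the mid-word penalty_scale factor is left out for clarity):

#include <cstdio>

int main() {
    const float length_penalty   = 1.2f; // example value from the code comment
    const float presence_penalty = 1.0f; // illustrative; its term cancels here
    const int   count            = 3;    // times the token repeated

    // Naive multiplication overshoots: 1.2 * 3 = 3.6.
    const float naive = length_penalty * float(count);

    // The commit's formula: scale the fractional parts, then add 1,
    // giving 3 * 0.2 + 0 + 1 = 1.6.
    const float scaled = float(count) * (length_penalty - 1.0f)
                       + float(count > 0) * (presence_penalty - 1.0f)
                       + 1.0f;

    std::printf("naive  = %.2f\n", naive);                  // 3.60
    std::printf("scaled = %.2f\n", scaled);                 // 1.60
    std::printf("10.0 / scaled = %.4f\n", 10.0f / scaled);  // 6.2500
    return 0;
}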
