@@ -2825,27 +2825,32 @@ void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_ar
2825
2825
const bool ends_on_word = params->mid_word_scale == 1 .0f
2826
2826
|| (llama_seqrep_check_word (ctx, last_tokens_p[last_tokens_size - 1 ]) & 2 ) != 0 ;
2827
2827
2828
- for (const auto it : penalize_tokens) {
2829
- const bool pt_starts_word = params->mid_word_scale == 1 .0f ||
2830
- (llama_seqrep_check_word (ctx, it.first ) & 1 ) != 0 ;
2831
- float scale = ends_on_word || pt_starts_word ? 1 .0f : params->mid_word_scale ;
2828
+ for (size_t i = 0 ; i < candidates->size ; ++i) {
2829
+ auto pt_iter = penalize_tokens.find (candidates->data [i].id );
2830
+ if (pt_iter == penalize_tokens.end ()) {
2831
+ continue ;
2832
+ }
2832
2833
2833
- float logit = candidates->data [it.first ].logit ;
2834
+ const size_t count = pt_iter->second ;
2835
+ const bool pt_starts_word = params->mid_word_scale == 1 .0f ||
2836
+ (llama_seqrep_check_word (ctx, candidates->data [i].id ) & 1 ) != 0 ;
2837
+ float penalty_scale = ends_on_word || pt_starts_word ? 1 .0f : params->mid_word_scale ;
2838
+ float logit = candidates->data [i].logit ;
2834
2839
2835
2840
if ((flags & LLAMA_SEQREP_DIVIDE_BY_PENALTY) == 0 ) {
2836
2841
float penalty =
2837
- ( float (it. second ) * params->length_penalty
2838
- + float (it. second > 0 ) * params->presence_penalty );
2839
- logit -= penalty * scale ;
2842
+ ( float (count ) * params->length_penalty
2843
+ + float (count > 0 ) * params->presence_penalty );
2844
+ logit -= penalty * penalty_scale ;
2840
2845
} else {
2841
2846
// This looks complicated. The point is to be able to scale penalties like
2842
2847
// 1.2. For example, suppose length penalty is 1.2 and length is 3. 1.2 * 3 = 3.6
2843
2848
// would be ridiculous. What we actually want is more like 1.6.
2844
2849
// An alternative approach would be to iteratively apply the scale.
2845
2850
// 10.0 / 1.6 == 6.25, however ((10.0 / 1.2) / 1.2) / 1.2 == 5.787
2846
2851
float penalty =
2847
- ( (float (it. second ) * (params->length_penalty - 1 .0f ))
2848
- + (float (it. second > 0 ) * (params->presence_penalty - 1 .0f )) ) * scale
2852
+ ( (float (count ) * (params->length_penalty - 1 .0f ))
2853
+ + (float (count > 0 ) * (params->presence_penalty - 1 .0f )) ) * penalty_scale
2849
2854
+ 1 .0f ;
2850
2855
if (logit <= 0 ) {
2851
2856
logit *= penalty;
@@ -2857,7 +2862,7 @@ void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_ar
2857
2862
logit = 0 .0f ;
2858
2863
}
2859
2864
}
2860
- candidates->data [it. first ].logit = logit;
2865
+ candidates->data [i ].logit = logit;
2861
2866
}
2862
2867
2863
2868
candidates->sorted = false ;
0 commit comments