5
5
#undef NDEBUG
6
6
#endif
7
7
8
+ #include < cstring>
8
9
#include < cmath>
9
10
#include < numeric>
10
11
#include < cassert>
@@ -173,6 +174,79 @@ void test_frequency_presence_penalty(
173
174
}
174
175
}
175
176
177
+
178
+ // NOTE: Compares expected_probs at id position, not sorted position like the other
179
+ // test functions.
180
+ void test_seqrep_penalty (
181
+ const std::vector<float > & probs,
182
+ const std::vector<llama_token> & last_tokens,
183
+ const std::vector<float > & expected_probs,
184
+ const llama_sampler_seqrep_params * params) {
185
+ assert (probs.size () == expected_probs.size ());
186
+
187
+ size_t n_vocab = probs.size ();
188
+ std::vector<llama_token_data> candidates;
189
+ candidates.reserve (n_vocab);
190
+ for (llama_token token_id = 0 ; token_id < (llama_token)n_vocab; token_id++) {
191
+ float logit = log (probs[token_id]);
192
+ candidates.emplace_back (llama_token_data{token_id, logit, 0 .0f });
193
+ }
194
+
195
+ llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
196
+ llama_sample_softmax (nullptr , &candidates_p);
197
+ DUMP (&candidates_p);
198
+ llama_sample_seqrep_penalty (nullptr , &candidates_p, (llama_token *) last_tokens.data (), last_tokens.size (), params);
199
+ llama_sample_softmax (nullptr , &candidates_p);
200
+ DUMP (&candidates_p);
201
+
202
+ assert (candidates_p.size == expected_probs.size ());
203
+ for (size_t i = 0 ; i < candidates_p.size ; i++) {
204
+ assert (fabs (candidates_p.data [i].p - expected_probs[candidates_p.data [i].id ]) < 1e-3 );
205
+ }
206
+ }
207
+
208
+ void run_seqrep_tests (void ) {
209
+ llama_sampler_seqrep_params params;
210
+
211
+ // Compatible with frequency/presence penalty
212
+ memset (¶ms, 0 , sizeof (llama_sampler_seqrep_params));
213
+ params.last_n = 1024 ;
214
+ params.min_length = 1 ;
215
+ params.mid_word_scale = 1 .0f ;
216
+ params.presence_penalty = 5 .0f ;
217
+ params.length_penalty = 5 .0f ;
218
+ test_seqrep_penalty ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 }, {0 .000011f , 0 .249997f , 0 .249997f , 0 .249997f , 0 .249997f }, ¶ms);
219
+ test_seqrep_penalty ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 .000023f , 0 .000023f , 0 .000023f , 0 .499966f , 0 .499966f }, ¶ms);
220
+ test_seqrep_penalty ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 0 }, {0 .000000f , 0 .000023f , 0 .000023f , 0 .499977f , 0 .499977f }, ¶ms);
221
+
222
+ // Compatible with repetition penalty
223
+ memset (¶ms, 0 , sizeof (llama_sampler_seqrep_params));
224
+ params.last_n = 1024 ;
225
+ params.min_length = 1 ;
226
+ params.mid_word_scale = 1 .0f ;
227
+ params.presence_penalty = 50 .0f ;
228
+ params.length_penalty = 1 .0f ;
229
+ params.flags |= LLAMA_SEQREP_DIVIDE_BY_PENALTY;
230
+ test_seqrep_penalty ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 }, {0 , 0 .25f , 0 .25f , 0 .25f , 0 .25f }, ¶ms);
231
+ test_seqrep_penalty ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 , 0 , 0 , 0 .5f , 0 .5f }, ¶ms);
232
+ test_seqrep_penalty ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 0 }, {0 , 0 , 0 , 0 .5f , 0 .5f }, ¶ms);
233
+
234
+ // Seqrep mode
235
+ memset (¶ms, 0 , sizeof (llama_sampler_seqrep_params));
236
+ params.last_n = 1024 ;
237
+ params.min_length = 3 ;
238
+ params.mid_word_scale = 1 .0f ;
239
+ params.tolerance_half_step_cost = 1 .0f ;
240
+ params.presence_penalty = 50 .0f ;
241
+ params.length_penalty = 1 .0f ;
242
+ params.flags |= LLAMA_SEQREP_DIVIDE_BY_PENALTY;
243
+ test_seqrep_penalty ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 3 , 0 , 1 , 2 }, {0 .25f , 0 .25f , 0 .25f , 0 , 0 .25f }, ¶ms);
244
+ test_seqrep_penalty ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 2 , 2 , 3 , 0 , 1 , 2 }, {0 .20f , 0 .20f , 0 .20f , 0 .20f , 0 .20f }, ¶ms);
245
+ params.tolerance = 1 .0f ;
246
+ test_seqrep_penalty ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 2 , 2 , 3 , 0 , 1 , 2 }, {0 .25f , 0 .25f , 0 .25f , 0 , 0 .25f }, ¶ms);
247
+ }
248
+
249
+
176
250
int main (void ) {
177
251
ggml_time_init ();
178
252
@@ -199,6 +273,8 @@ int main(void) {
199
273
test_frequency_presence_penalty ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 .499966f , 0 .499966f , 0 .000023f , 0 .000023f , 0 .000023f }, 5 .0f , 5 .0f );
200
274
test_frequency_presence_penalty ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 0 }, {0 .499977f , 0 .499977f , 0 .000023f , 0 .000023f , 0 .000000f }, 5 .0f , 5 .0f );
201
275
276
+ run_seqrep_tests ();
277
+
202
278
printf (" OK\n " );
203
279
204
280
return 0 ;
0 commit comments