Skip to content

Commit 38d218d

Browse files
committed
seqrep: Add some tests.
1 parent cbfb2ef commit 38d218d

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

tests/test-sampling.cpp

Lines changed: 76 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
#undef NDEBUG
66
#endif
77

8+
#include <cstring>
89
#include <cmath>
910
#include <numeric>
1011
#include <cassert>
@@ -173,6 +174,79 @@ void test_frequency_presence_penalty(
173174
}
174175
}
175176

177+
178+
// NOTE: Compares expected_probs at id position, not sorted position like the other
179+
// test functions.
180+
void test_seqrep_penalty(
181+
const std::vector<float> & probs,
182+
const std::vector<llama_token> & last_tokens,
183+
const std::vector<float> & expected_probs,
184+
const llama_sampler_seqrep_params * params) {
185+
assert(probs.size() == expected_probs.size());
186+
187+
size_t n_vocab = probs.size();
188+
std::vector<llama_token_data> candidates;
189+
candidates.reserve(n_vocab);
190+
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
191+
float logit = log(probs[token_id]);
192+
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
193+
}
194+
195+
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
196+
llama_sample_softmax(nullptr, &candidates_p);
197+
DUMP(&candidates_p);
198+
llama_sample_seqrep_penalty(nullptr, &candidates_p, (llama_token *) last_tokens.data(), last_tokens.size(), params);
199+
llama_sample_softmax(nullptr, &candidates_p);
200+
DUMP(&candidates_p);
201+
202+
assert(candidates_p.size == expected_probs.size());
203+
for (size_t i = 0; i < candidates_p.size; i++) {
204+
assert(fabs(candidates_p.data[i].p - expected_probs[candidates_p.data[i].id]) < 1e-3);
205+
}
206+
}
207+
208+
void run_seqrep_tests(void) {
209+
llama_sampler_seqrep_params params;
210+
211+
// Compatible with frequency/presence penalty
212+
memset(&params, 0, sizeof(llama_sampler_seqrep_params));
213+
params.last_n = 1024;
214+
params.min_length = 1;
215+
params.mid_word_scale = 1.0f;
216+
params.presence_penalty = 5.0f;
217+
params.length_penalty = 5.0f;
218+
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.000011f, 0.249997f, 0.249997f, 0.249997f, 0.249997f}, &params);
219+
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.000023f, 0.000023f, 0.000023f, 0.499966f, 0.499966f}, &params);
220+
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.000000f, 0.000023f, 0.000023f, 0.499977f, 0.499977f}, &params);
221+
222+
// Compatible with repetition penalty
223+
memset(&params, 0, sizeof(llama_sampler_seqrep_params));
224+
params.last_n = 1024;
225+
params.min_length = 1;
226+
params.mid_word_scale = 1.0f;
227+
params.presence_penalty = 50.0f;
228+
params.length_penalty = 1.0f;
229+
params.flags |= LLAMA_SEQREP_DIVIDE_BY_PENALTY;
230+
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0, 0.25f, 0.25f, 0.25f, 0.25f}, &params);
231+
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0, 0, 0, 0.5f, 0.5f}, &params);
232+
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0, 0, 0, 0.5f, 0.5f}, &params);
233+
234+
// Seqrep mode
235+
memset(&params, 0, sizeof(llama_sampler_seqrep_params));
236+
params.last_n = 1024;
237+
params.min_length = 3;
238+
params.mid_word_scale = 1.0f;
239+
params.tolerance_half_step_cost = 1.0f;
240+
params.presence_penalty = 50.0f;
241+
params.length_penalty = 1.0f;
242+
params.flags |= LLAMA_SEQREP_DIVIDE_BY_PENALTY;
243+
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 0, 1, 2}, {0.25f, 0.25f, 0.25f, 0, 0.25f}, &params);
244+
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 2, 2, 3, 0, 1, 2}, {0.20f, 0.20f, 0.20f, 0.20f, 0.20f}, &params);
245+
params.tolerance = 1.0f;
246+
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 2, 2, 3, 0, 1, 2}, {0.25f, 0.25f, 0.25f, 0, 0.25f}, &params);
247+
}
248+
249+
176250
int main(void) {
177251
ggml_time_init();
178252

@@ -199,6 +273,8 @@ int main(void) {
199273
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 5.0f, 5.0f);
200274
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 5.0f, 5.0f);
201275

276+
run_seqrep_tests();
277+
202278
printf("OK\n");
203279

204280
return 0;

0 commit comments

Comments
 (0)