Skip to content

Commit 3bf3a96

Browse files
Tests
1 parent 416f491 commit 3bf3a96

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

tests/test-sampling.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
#include <vector>
88
#include <algorithm>
99

10+
#undef assert
11+
#define assert(__expr) do { if (!(__expr)) { printf("%s:%d (%s) %s\n", __FILE__, __LINE__, __func__, #__expr); exit(1); } } while(0)
12+
1013
void dump(const llama_token_data_array * candidates) {
1114
for (size_t i = 0; i < candidates->size; i++) {
1215
printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
@@ -53,13 +56,14 @@ void test_top_p(const std::vector<float> & probs,
5356
}
5457

5558
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
59+
llama_sample_softmax(nullptr, &candidates_p);
5660
// DUMP(&candidates_p);
5761
llama_sample_top_p(nullptr, &candidates_p, p);
5862
// DUMP(&candidates_p);
5963

6064
assert(candidates_p.size == expected_probs.size());
6165
for (size_t i = 0; i < candidates_p.size; i++) {
62-
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
66+
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
6367
}
6468
}
6569

@@ -82,7 +86,7 @@ void test_tfs(const std::vector<float> & probs,
8286

8387
assert(candidates_p.size == expected_probs.size());
8488
for (size_t i = 0; i < candidates_p.size; i++) {
85-
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
89+
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
8690
}
8791
}
8892

@@ -105,7 +109,7 @@ void test_typical(const std::vector<float> & probs,
105109

106110
assert(candidates_p.size == expected_probs.size());
107111
for (size_t i = 0; i < candidates_p.size; i++) {
108-
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
112+
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
109113
}
110114
}
111115

@@ -163,7 +167,7 @@ void test_frequency_presence_penalty(
163167

164168
assert(candidates_p.size == expected_probs.size());
165169
for (size_t i = 0; i < candidates_p.size; i++) {
166-
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
170+
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
167171
}
168172
}
169173

@@ -182,9 +186,9 @@ int main(void) {
182186
test_typical({0.97, 0.01, 0.01, 0.01}, {0.97}, 0.5);
183187
test_typical({0.4, 0.2, 0.2, 0.2}, {0.2, 0.2, 0.2}, 0.5);
184188

185-
test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0, 0.25, 0.25, 0.25, 0.25}, 50.0);
186-
test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0, 0, 0, 0.5, 0.5}, 50.0);
187-
test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0, 0, 0, 0.5, 0.5}, 50.0);
189+
test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.25, 0.25, 0.25, 0.25, 0}, 50.0);
190+
test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.5, 0.5, 0, 0, 0}, 50.0);
191+
test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.5, 0.5, 0, 0, 0}, 50.0);
188192

189193
test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.249997, 0.249997, 0.249997, 0.249997, 0.000011}, 5.0, 5.0);
190194
test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.499966, 0.499966, 0.000023, 0.000023, 0.000023}, 5.0, 5.0);

0 commit comments

Comments
 (0)