Skip to content

Commit 33a69ec

Browse files
committed
tests : replace macros with functions
ggml-ci
1 parent e31c879 commit 33a69ec

File tree

1 file changed

+108
-98
lines changed

1 file changed

+108
-98
lines changed

tests/test-sampling.cpp

Lines changed: 108 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -18,155 +18,165 @@ static void dump(const llama_token_data_array * cur_p) {
1818

1919
#define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
2020

21-
#define APPLY(__cnstr, __cur_p) do { \
22-
auto * cnstr = (__cnstr); \
23-
llama_sampler_apply(cnstr, (__cur_p)); \
24-
llama_sampler_free(cnstr); \
25-
} while(0)
26-
27-
#define CUR_P_FROM_PROBS() \
28-
const size_t n_vocab = probs.size(); \
29-
std::vector<llama_token_data> cur; \
30-
cur.reserve(n_vocab); \
31-
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { \
32-
const float logit = logf(probs[token_id]); \
33-
cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); \
34-
} \
35-
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }
36-
37-
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
38-
CUR_P_FROM_PROBS();
39-
40-
DUMP(&cur_p);
41-
APPLY(llama_sampler_init_top_k(k), &cur_p);
42-
APPLY(llama_sampler_init_dist (0), &cur_p);
43-
DUMP(&cur_p);
44-
45-
GGML_ASSERT(cur_p.size == expected_probs.size());
46-
for (size_t i = 0; i < cur_p.size; i++) {
47-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
21+
struct sampler_tester {
22+
sampler_tester(size_t n_vocab) {
23+
cur.reserve(n_vocab);
24+
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
25+
const float logit = logf(token_id);
26+
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
27+
}
28+
29+
cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
4830
}
49-
}
5031

51-
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
52-
CUR_P_FROM_PROBS();
32+
sampler_tester(const std::vector<float> & probs, const std::vector<float> & probs_expected) : probs_expected(probs_expected) {
33+
cur.reserve(probs.size());
34+
for (llama_token token_id = 0; token_id < (llama_token)probs.size(); token_id++) {
35+
const float logit = logf(probs[token_id]);
36+
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
37+
}
5338

54-
DUMP(&cur_p);
55-
APPLY(llama_sampler_init_top_p(p, 1), &cur_p);
56-
APPLY(llama_sampler_init_dist (0), &cur_p);
57-
DUMP(&cur_p);
58-
DUMP(&cur_p);
39+
cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
40+
}
5941

60-
GGML_ASSERT(cur_p.size == expected_probs.size());
61-
for (size_t i = 0; i < cur_p.size; i++) {
62-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
42+
void apply(llama_sampler * sampler) {
43+
llama_sampler_apply(sampler, &cur_p);
44+
llama_sampler_free(sampler);
6345
}
46+
47+
void check() {
48+
GGML_ASSERT(cur_p.size == probs_expected.size());
49+
for (size_t i = 0; i < cur_p.size; i++) {
50+
GGML_ASSERT(fabs(cur_p.data[i].p - probs_expected[i]) < 1e-5);
51+
}
52+
}
53+
54+
llama_token_data_array cur_p;
55+
56+
private:
57+
const std::vector<float> probs_expected;
58+
59+
std::vector<llama_token_data> cur;
60+
};
61+
62+
static void test_temp(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp) {
63+
sampler_tester tester(probs, probs_expected);
64+
65+
DUMP(&tester.cur_p);
66+
tester.apply(llama_sampler_init_temp(temp));
67+
tester.apply(llama_sampler_init_dist(0));
68+
DUMP(&tester.cur_p);
69+
70+
tester.check();
6471
}
6572

66-
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
67-
CUR_P_FROM_PROBS();
73+
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
74+
sampler_tester tester(probs, probs_expected);
6875

69-
DUMP(&cur_p);
70-
APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
71-
DUMP(&cur_p);
76+
DUMP(&tester.cur_p);
77+
tester.apply(llama_sampler_init_top_k(k));
78+
tester.apply(llama_sampler_init_dist (0));
79+
DUMP(&tester.cur_p);
7280

73-
GGML_ASSERT(cur_p.size == expected_probs.size());
74-
for (size_t i = 0; i < cur_p.size; i++) {
75-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
76-
}
81+
tester.check();
7782
}
7883

79-
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
80-
CUR_P_FROM_PROBS();
84+
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
85+
sampler_tester tester(probs, probs_expected);
8186

82-
DUMP(&cur_p);
83-
APPLY(llama_sampler_init_min_p(p, 1), &cur_p);
84-
APPLY(llama_sampler_init_dist (0), &cur_p);
85-
DUMP(&cur_p);
87+
DUMP(&tester.cur_p);
88+
tester.apply(llama_sampler_init_top_p(p, 1));
89+
tester.apply(llama_sampler_init_dist (0));
90+
DUMP(&tester.cur_p);
8691

87-
GGML_ASSERT(cur_p.size == expected_probs.size());
88-
for (size_t i = 0; i < cur_p.size; i++) {
89-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
90-
}
92+
tester.check();
9193
}
9294

93-
static void test_xtc(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p, float t) {
94-
CUR_P_FROM_PROBS();
95+
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & probs_expected, float z) {
96+
sampler_tester tester(probs, probs_expected);
9597

96-
DUMP(&cur_p);
97-
APPLY(llama_sampler_init_xtc(p, t, 0, 0), &cur_p);
98-
DUMP(&cur_p);
98+
DUMP(&tester.cur_p);
99+
tester.apply(llama_sampler_init_tail_free(z, 1));
100+
DUMP(&tester.cur_p);
99101

100-
GGML_ASSERT(cur_p.size == expected_probs.size());
101-
for (size_t i = 0; i < cur_p.size; i++) {
102-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
103-
}
102+
tester.check();
104103
}
105104

106-
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
107-
CUR_P_FROM_PROBS();
105+
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
106+
sampler_tester tester(probs, probs_expected);
108107

109-
DUMP(&cur_p);
110-
APPLY(llama_sampler_init_typical(p, 1), &cur_p);
111-
DUMP(&cur_p);
108+
DUMP(&tester.cur_p);
109+
tester.apply(llama_sampler_init_min_p(p, 1));
110+
tester.apply(llama_sampler_init_dist (0));
111+
DUMP(&tester.cur_p);
112112

113-
GGML_ASSERT(cur_p.size == expected_probs.size());
114-
for (size_t i = 0; i < cur_p.size; i++) {
115-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
116-
}
113+
tester.check();
114+
}
115+
116+
static void test_xtc(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p, float t) {
117+
sampler_tester tester(probs, probs_expected);
118+
119+
DUMP(&tester.cur_p);
120+
tester.apply(llama_sampler_init_xtc(p, t, 0, 0));
121+
DUMP(&tester.cur_p);
122+
123+
tester.check();
124+
}
125+
126+
static void test_typical(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
127+
sampler_tester tester(probs, probs_expected);
128+
129+
DUMP(&tester.cur_p);
130+
tester.apply(llama_sampler_init_typical(p, 1));
131+
DUMP(&tester.cur_p);
132+
133+
tester.check();
117134
}
118135

119136
static void test_penalties(
120137
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
121-
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
138+
const std::vector<float> & probs_expected, float repeat_penalty, float alpha_frequency, float alpha_presence
122139
) {
123-
GGML_ASSERT(probs.size() == expected_probs.size());
140+
GGML_ASSERT(probs.size() == probs_expected.size());
124141

125-
CUR_P_FROM_PROBS();
142+
sampler_tester tester(probs, probs_expected);
126143

144+
const size_t n_vocab = probs.size();
127145
auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
128146

129147
for (size_t i = 0; i < last_tokens.size(); i++) {
130148
llama_sampler_accept(sampler, last_tokens[i]);
131149
}
132150

133-
DUMP(&cur_p);
134-
APPLY(sampler, &cur_p);
135-
APPLY(llama_sampler_init_dist(0), &cur_p);
136-
DUMP(&cur_p);
151+
DUMP(&tester.cur_p);
152+
tester.apply(sampler);
153+
tester.apply(llama_sampler_init_dist(0));
154+
DUMP(&tester.cur_p);
137155

138-
GGML_ASSERT(cur_p.size == expected_probs.size());
139-
for (size_t i = 0; i < cur_p.size; i++) {
140-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
141-
}
156+
tester.check();
142157
}
143158

144159
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
145160
) {
146-
std::vector<llama_token_data> cur;
147-
cur.reserve(n_vocab);
148-
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
149-
const float logit = logf(token_id);
150-
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
151-
}
152-
153-
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
161+
sampler_tester tester(n_vocab);
154162

155163
llama_token min_token_id = 0;
156164
const llama_token max_token_id = n_vocab-1;
157165

158166
for (auto s : samplers_sequence) {
159167
switch (s){
160-
case 'k': APPLY(llama_sampler_init_top_k(top_k), &cur_p); break;
168+
case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
161169
case 'f': GGML_ABORT("tail_free test not implemented");
162170
case 'y': GGML_ABORT("typical test not implemented");
163-
case 'p': APPLY(llama_sampler_init_top_p(top_p, 1), &cur_p); break;
164-
case 'm': APPLY(llama_sampler_init_min_p(min_p, 1), &cur_p); break;
171+
case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
172+
case 'm': tester.apply(llama_sampler_init_min_p(min_p, 1)); break;
165173
case 't': GGML_ABORT("temperature test not implemented");
166174
default : GGML_ABORT("Unknown sampler");
167175
}
168176

169-
APPLY(llama_sampler_init_dist(0), &cur_p);
177+
tester.apply(llama_sampler_init_dist(0));
178+
179+
auto & cur_p = tester.cur_p;
170180

171181
const int size = cur_p.size;
172182

0 commit comments

Comments
 (0)