Skip to content

Commit 21e7ab4

Browse files
Sample interface, new samplers,
ignore EOS should apply -inf to EOS logit. New samplers: - locally typical sampling - tail free sampling - frequency and presence penalty - mirostat
1 parent 4dbbd40 commit 21e7ab4

File tree

7 files changed

+450
-131
lines changed

7 files changed

+450
-131
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
6868
# Compile flags
6969
#
7070

71-
set(CMAKE_CXX_STANDARD 11)
71+
set(CMAKE_CXX_STANDARD 20)
7272
set(CMAKE_CXX_STANDARD_REQUIRED true)
7373
set(CMAKE_C_STANDARD 11)
7474
set(CMAKE_C_STANDARD_REQUIRED true)

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ endif
3232

3333
# keep standard at C11 and C++11
3434
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
35-
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
35+
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++20 -fPIC
3636
LDFLAGS =
3737

3838
# warnings

examples/common.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
114114
break;
115115
}
116116
params.temp = std::stof(argv[i]);
117+
} else if (arg == "--tfs") {
118+
if (++i >= argc) {
119+
invalid_param = true;
120+
break;
121+
}
122+
params.tfs_z = std::stof(argv[i]);
123+
} else if (arg == "--typical") {
124+
if (++i >= argc) {
125+
invalid_param = true;
126+
break;
127+
}
128+
params.typical_p = std::stof(argv[i]);
117129
} else if (arg == "--repeat_last_n") {
118130
if (++i >= argc) {
119131
invalid_param = true;
@@ -126,6 +138,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
126138
break;
127139
}
128140
params.repeat_penalty = std::stof(argv[i]);
141+
} else if (arg == "--alpha_frequency") {
142+
if (++i >= argc) {
143+
invalid_param = true;
144+
break;
145+
}
146+
params.alpha_frequency = std::stof(argv[i]);
147+
} else if (arg == "--alpha_presence") {
148+
if (++i >= argc) {
149+
invalid_param = true;
150+
break;
151+
}
152+
params.alpha_presence = std::stof(argv[i]);
129153
} else if (arg == "-b" || arg == "--batch_size") {
130154
if (++i >= argc) {
131155
invalid_param = true;
@@ -230,6 +254,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
230254
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
231255
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
232256
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
257+
fprintf(stderr, " --tfs N tail free sampling (default: %.1f)\n", (double)params.tfs_z);
258+
fprintf(stderr, " --typical N locally typical sampling (default: %.1f)\n", (double)params.typical_p);
259+
fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %d)\n", params.alpha_presence);
260+
fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f)\n", (double)params.alpha_frequency);
233261
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
234262
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
235263
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);

examples/main/main.cpp

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ int main(int argc, char ** argv) {
215215
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
216216
}
217217
}
218-
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
219-
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
218+
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n",
219+
params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp);
220220
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
221221
fprintf(stderr, "\n\n");
222222

@@ -281,23 +281,69 @@ int main(int argc, char ** argv) {
281281

282282
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
283283
// out of user input, sample next token
284-
const int32_t top_k = params.top_k;
285-
const float top_p = params.top_p;
286284
const float temp = params.temp;
285+
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
286+
const float top_p = params.top_p;
287+
const float tfs_z = params.tfs_z;
288+
const float typical_p = params.typical_p;
289+
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
287290
const float repeat_penalty = params.repeat_penalty;
291+
const float alpha_presence = params.alpha_presence;
292+
const float alpha_frequency = params.alpha_frequency;
288293

289294
llama_token id = 0;
290295

291296
{
292297
auto logits = llama_get_logits(ctx);
298+
auto n_vocab = llama_n_vocab(ctx);
293299

294300
if (params.ignore_eos) {
295-
logits[llama_token_eos()] = 0;
301+
logits[llama_token_eos()] = -INFINITY;
302+
}
303+
304+
std::vector<llama_token_data> candidates;
305+
candidates.reserve(n_vocab);
306+
for (size_t i = 0; i < n_vocab; i++) {
307+
candidates.emplace_back(i, logits[i], 0.0f);
296308
}
297309

298-
id = llama_sample_top_p_top_k(ctx,
299-
last_n_tokens.data() + n_ctx - params.repeat_last_n,
300-
params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
310+
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
311+
312+
// Apply penalties
313+
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
314+
llama_sample_repetition_penalty(&candidates_p,
315+
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
316+
last_n_repeat, repeat_penalty);
317+
llama_sample_frequency_and_presence_penalties(&candidates_p,
318+
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
319+
last_n_repeat, alpha_frequency, alpha_presence);
320+
321+
322+
#if 1
323+
if (temp <= 0) {
324+
// Greedy sampling
325+
id = llama_sample_token_greedy(ctx, &candidates_p);
326+
} else {
327+
// Temperature sampling
328+
llama_sample_top_k(&candidates_p, top_k);
329+
llama_sample_tail_free(&candidates_p, tfs_z);
330+
llama_sample_typical(&candidates_p, typical_p);
331+
llama_sample_top_p(&candidates_p, top_p);
332+
333+
llama_sample_temperature(&candidates_p, temp);
334+
// printf("`%d`", candidates_p.size);
335+
id = llama_sample_token(ctx, &candidates_p);
336+
}
337+
#else
338+
const float tau = 5.0f;
339+
static float mu = 2.0f * tau;
340+
static int k = 40;
341+
const float eta = 0.1f;
342+
const int m = 100;
343+
const float N = n_vocab;
344+
id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
345+
// id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
346+
#endif
301347

302348
last_n_tokens.erase(last_n_tokens.begin());
303349
last_n_tokens.push_back(id);

0 commit comments

Comments
 (0)