Commit dd7eff5

llama : new sampling algorithms (#1126)
* Sample interface, new samplers.

  New samplers:
  - locally typical sampling
  - tail free sampling
  - frequency and presence penalty
  - mirostat

  Ignore EOS fix: -inf should be used.

* mirostat

* Added --logit-bias and --no-penalize-nl, removed std::span

* Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)

* Save and load example adjust

* Tests

* Windows build fix

* Windows test fix
1 parent 7fc50c0 commit dd7eff5
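
In short, the commit retires the monolithic llama_sample_top_p_top_k() call: callers now build a llama_token_data_array of (token, logit) candidates and pass it through whichever samplers they want, each of which filters or reweights the array in place. Below is a minimal sketch of the default (non-Mirostat) chain, condensed from the examples/main/main.cpp changes further down — the parameter values and the assumption of an already-evaluated llama_context are illustrative, not part of the commit:

#include <vector>
#include "llama.h"

// Sketch: draw one token using the new composable sampler API.
// Assumes `ctx` holds a model that has already evaluated some tokens.
static llama_token sample_next(llama_context * ctx) {
    float * logits = llama_get_logits(ctx);
    const int n_vocab = llama_n_vocab(ctx);

    // One candidate per vocabulary entry, carrying its raw logit.
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    // Each call narrows or reweights the candidate set; the last one samples.
    llama_sample_top_k(ctx, &candidates_p, 40);         // values here are illustrative
    llama_sample_tail_free(ctx, &candidates_p, 0.95f);
    llama_sample_typical(ctx, &candidates_p, 0.95f);
    llama_sample_top_p(ctx, &candidates_p, 0.90f);
    llama_sample_temperature(ctx, &candidates_p, 0.80f);
    return llama_sample_token(ctx, &candidates_p);
}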

File tree: 8 files changed, +808 −156 lines changed

examples/common.cpp

Lines changed: 85 additions & 6 deletions
@@ -6,6 +6,8 @@
 #include <string>
 #include <iterator>
 #include <algorithm>
+#include <sstream>
+#include <iostream>
 
 #if defined (_WIN32)
 #include <fcntl.h>
@@ -114,6 +116,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.temp = std::stof(argv[i]);
+        } else if (arg == "--tfs") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.tfs_z = std::stof(argv[i]);
+        } else if (arg == "--typical") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.typical_p = std::stof(argv[i]);
         } else if (arg == "--repeat_last_n") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -126,6 +140,36 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.repeat_penalty = std::stof(argv[i]);
+        } else if (arg == "--frequency_penalty") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.frequency_penalty = std::stof(argv[i]);
+        } else if (arg == "--presence_penalty") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.presence_penalty = std::stof(argv[i]);
+        } else if (arg == "--mirostat") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.mirostat = std::stoi(argv[i]);
+        } else if (arg == "--mirostat_lr") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.mirostat_eta = std::stof(argv[i]);
+        } else if (arg == "--mirostat_ent") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.mirostat_tau = std::stof(argv[i]);
         } else if (arg == "-b" || arg == "--batch_size") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -185,7 +229,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else if (arg == "--perplexity") {
             params.perplexity = true;
         } else if (arg == "--ignore-eos") {
-            params.ignore_eos = true;
+            params.logit_bias[llama_token_eos()] = -INFINITY;
+        } else if (arg == "--no-penalize-nl") {
+            params.penalize_nl = false;
+        } else if (arg == "-l" || arg == "--logit-bias") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::stringstream ss(argv[i]);
+            llama_token key;
+            char sign;
+            std::string value_str;
+            try {
+                if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
+                    params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                } else {
+                    throw std::exception();
+                }
+            } catch (const std::exception &e) {
+                invalid_param = true;
+                break;
+            }
         } else if (arg == "--n_parts") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -240,12 +305,26 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -f FNAME, --file FNAME\n");
     fprintf(stderr, "                        prompt file to start generation.\n");
     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
-    fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
-    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", (double)params.top_p);
-    fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
-    fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
+    fprintf(stderr, "  --top_k N             top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
+    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
+    fprintf(stderr, "  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
+    fprintf(stderr, "  --typical N           locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
+    fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
+    fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
+    fprintf(stderr, "  --presence_penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
+    fprintf(stderr, "  --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
+    fprintf(stderr, "  --mirostat N          use Mirostat sampling.\n");
+    fprintf(stderr, "                        Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
+    fprintf(stderr, "                        (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
+    fprintf(stderr, "  --mirostat_lr N       Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
+    fprintf(stderr, "  --mirostat_ent N      Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
+    fprintf(stderr, "  -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
+    fprintf(stderr, "                        modifies the likelihood of token appearing in the completion,\n");
+    fprintf(stderr, "                        i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
+    fprintf(stderr, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
     fprintf(stderr, "  -c N, --ctx_size N    size of the prompt context (default: %d)\n", params.n_ctx);
-    fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating\n");
+    fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
+    fprintf(stderr, "  --no-penalize-nl      do not penalize newline token\n");
     fprintf(stderr, "  --memory_f32          use f32 instead of f16 for memory key+value\n");
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", (double)params.temp);
     fprintf(stderr, "  --n_parts N           number of model parts (default: -1 = determine from dimensions)\n");
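
The --logit-bias value packs a token id, a sign, and a magnitude into a single string such as 15043+1: `ss >> key` stops at the sign character, `ss >> sign` consumes it, and the remainder is parsed with std::stof (whose exceptions are what the try/catch absorbs). Here is a self-contained sketch of the same parsing logic; the function name and test driver are illustrative, not from the commit:

#include <iostream>
#include <sstream>
#include <string>

// Parse "TOKEN_ID(+/-)BIAS", e.g. "15043+1" or "15043-1".
// Returns false on malformed input, mirroring the invalid_param path above.
static bool parse_logit_bias(const std::string & arg, int & key, float & bias) {
    std::stringstream ss(arg);
    char sign = 0;
    std::string value_str;
    try {
        if (ss >> key && ss >> sign && std::getline(ss, value_str) &&
                (sign == '+' || sign == '-')) {
            bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
            return true;
        }
    } catch (const std::exception &) {
        // std::stof throws on non-numeric input; treat it as a parse failure
    }
    return false;
}

int main() {
    int key; float bias;
    for (const char * s : {"15043+1", "15043-1", "15043+oops"}) {
        if (parse_logit_bias(s, key, bias)) {
            std::cout << s << " -> token " << key << ", bias " << bias << "\n";
        } else {
            std::cout << s << " -> invalid\n";
        }
    }
    return 0;
}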

examples/common.h

Lines changed: 15 additions & 6 deletions
@@ -8,6 +8,7 @@
 #include <vector>
 #include <random>
 #include <thread>
+#include <unordered_map>
 
 //
 // CLI argument parsing
@@ -17,17 +18,25 @@ struct gpt_params {
     int32_t seed          = -1;  // RNG seed
     int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_predict     = 128; // new tokens to predict
-    int32_t repeat_last_n = 64;  // last n tokens to penalize
     int32_t n_parts       = -1;  // amount of model parts (-1 = determine from model dimensions)
     int32_t n_ctx         = 512; // context size
     int32_t n_batch       = 512; // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep        = 0;   // number of tokens to keep from initial prompt
 
     // sampling parameters
-    int32_t top_k = 40;
-    float   top_p = 0.95f;
-    float   temp  = 0.80f;
-    float   repeat_penalty = 1.10f;
+    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+    int32_t top_k             = 0;    // <= 0 to use vocab size
+    float   top_p             = 1.0f; // 1.0 = disabled
+    float   tfs_z             = 1.0f; // 1.0 = disabled
+    float   typical_p         = 1.0f; // 1.0 = disabled
+    float   temp              = 1.0f; // 1.0 = disabled
+    float   repeat_penalty    = 1.0f; // 1.0 = disabled
+    int32_t repeat_last_n     = -1;   // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float   frequency_penalty = 0.0f; // 0.0 = disabled
+    float   presence_penalty  = 0.0f; // 0.0 = disabled
+    int     mirostat          = 0;    // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float   mirostat_tau      = 5.0f; // target entropy
+    float   mirostat_eta      = 0.1f; // learning rate
 
     std::string model  = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
@@ -47,7 +56,7 @@ struct gpt_params {
     bool interactive_first = false; // wait for user input immediately
 
     bool instruct    = false; // instruction mode (used for Alpaca models)
-    bool ignore_eos  = false; // do not stop generating after eos
+    bool penalize_nl = true;  // consider newlines as a repeatable token
     bool perplexity  = false; // compute perplexity over the prompt
     bool use_mmap    = true;  // use mmap for faster loads
     bool use_mlock   = false; // use mlock to keep model in memory
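
The new defaults are chosen so every sampler starts out inert: top_k = 0 expands to the full vocabulary, the 1.0f defaults for top_p, tfs_z, typical_p, temp, and repeat_penalty are identity transforms, and the 0.0f penalties add nothing, so an untouched gpt_params samples from the raw softmax. Two of the fields are sentinels that calling code must resolve; a small sketch of that resolution, mirroring the main.cpp changes below (the helper names are illustrative):

#include <cstdint>

// top_k <= 0 means "consider the whole vocabulary".
static int32_t resolve_top_k(int32_t top_k, int32_t n_vocab) {
    return top_k <= 0 ? n_vocab : top_k;
}

// repeat_last_n < 0 means "penalize across the whole context window";
// 0 disables the repetition window entirely.
static int32_t resolve_repeat_last_n(int32_t repeat_last_n, int32_t n_ctx) {
    return repeat_last_n < 0 ? n_ctx : repeat_last_n;
}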

examples/main/main.cpp

Lines changed: 62 additions & 9 deletions
@@ -276,8 +276,8 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
         }
     }
-    fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
-        params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
+        params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
@@ -387,10 +387,19 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // out of user input, sample next token
-            const int32_t top_k = params.top_k;
-            const float   top_p = params.top_p;
             const float   temp            = params.temp;
+            const int32_t top_k           = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+            const float   top_p           = params.top_p;
+            const float   tfs_z           = params.tfs_z;
+            const float   typical_p       = params.typical_p;
+            const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
             const float   repeat_penalty  = params.repeat_penalty;
+            const float   alpha_presence  = params.presence_penalty;
+            const float   alpha_frequency = params.frequency_penalty;
+            const int     mirostat        = params.mirostat;
+            const float   mirostat_tau    = params.mirostat_tau;
+            const float   mirostat_eta    = params.mirostat_eta;
+            const bool    penalize_nl     = params.penalize_nl;
 
             // optionally save the session on first sample (for faster prompt loading next time)
             if (!path_session.empty() && need_to_save_session) {
@@ -402,14 +411,58 @@ int main(int argc, char ** argv) {
 
             {
                 auto logits  = llama_get_logits(ctx);
+                auto n_vocab = llama_n_vocab(ctx);
 
-                if (params.ignore_eos) {
-                    logits[llama_token_eos()] = 0;
+                // Apply params.logit_bias map
+                for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+                    logits[it->first] += it->second;
                 }
 
-                id = llama_sample_top_p_top_k(ctx,
-                        last_n_tokens.data() + n_ctx - params.repeat_last_n,
-                        params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
+                std::vector<llama_token_data> candidates;
+                candidates.reserve(n_vocab);
+                for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                    candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+                }
+
+                llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+                // Apply penalties
+                float nl_logit = logits[llama_token_nl()];
+                auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
+                llama_sample_repetition_penalty(ctx, &candidates_p,
+                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                    last_n_repeat, repeat_penalty);
+                llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                    last_n_repeat, alpha_frequency, alpha_presence);
+                if (!penalize_nl) {
+                    logits[llama_token_nl()] = nl_logit;
+                }
+
+                if (temp <= 0) {
+                    // Greedy sampling
+                    id = llama_sample_token_greedy(ctx, &candidates_p);
+                } else {
+                    if (mirostat == 1) {
+                        static float mirostat_mu = 2.0f * mirostat_tau;
+                        const int mirostat_m = 100;
+                        llama_sample_temperature(ctx, &candidates_p, temp);
+                        id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+                    } else if (mirostat == 2) {
+                        static float mirostat_mu = 2.0f * mirostat_tau;
+                        llama_sample_temperature(ctx, &candidates_p, temp);
+                        id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+                    } else {
+                        // Temperature sampling
+                        llama_sample_top_k(ctx, &candidates_p, top_k);
+                        llama_sample_tail_free(ctx, &candidates_p, tfs_z);
+                        llama_sample_typical(ctx, &candidates_p, typical_p);
+                        llama_sample_top_p(ctx, &candidates_p, top_p);
+                        llama_sample_temperature(ctx, &candidates_p, temp);
+                        id = llama_sample_token(ctx, &candidates_p);
+                    }
+                }
+                // printf("`%d`", candidates_p.size);
 
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
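
Mirostat carries one piece of state, mu, which the sampler nudges after every draw to keep the observed surprise near the target entropy tau; main.cpp stores it in a function-local static so it persists across iterations of the generation loop. Below is a sketch that threads mu through explicitly instead, the shape a caller managing several sequences would need (the wrapper function is illustrative; the llama_* calls are the ones from this commit):

#include <vector>
#include "llama.h"

// Sample one token with Mirostat 2.0. The sampler updates *mu in place,
// so the caller owns the state and can keep one mu per sequence.
static llama_token sample_mirostat_v2(llama_context * ctx, float tau, float eta, float * mu) {
    float * logits = llama_get_logits(ctx);
    const int n_vocab = llama_n_vocab(ctx);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    return llama_sample_token_mirostat_v2(ctx, &candidates_p, tau, eta, mu);
}

// Usage, matching main.cpp's initialization: mu starts at 2 * tau.
//   float mu = 2.0f * mirostat_tau;
//   llama_token id = sample_mirostat_v2(ctx, mirostat_tau, mirostat_eta, &mu);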

examples/save-load-state/save-load-state.cpp

Lines changed: 18 additions & 16 deletions
@@ -64,14 +64,15 @@ int main(int argc, char ** argv) {
     // first run
     printf("\n%s", params.prompt.c_str());
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token = llama_sample_top_p_top_k(
-            ctx,
-            &last_n_tokens_data.back() - params.repeat_last_n,
-            params.repeat_last_n,
-            40,
-            1.0,
-            1.0,
-            1.1);
+        auto logits = llama_get_logits(ctx);
+        auto n_vocab = llama_n_vocab(ctx);
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+        auto next_token = llama_sample_token(ctx, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx, next_token);
         last_n_tokens_data.push_back(next_token);
         printf("%s", next_token_str);
@@ -106,14 +107,15 @@ int main(int argc, char ** argv) {
 
     // second run
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token = llama_sample_top_p_top_k(
-            ctx2,
-            &last_n_tokens_data.back() - params.repeat_last_n,
-            params.repeat_last_n,
-            40,
-            1.0,
-            1.0,
-            1.1);
+        auto logits = llama_get_logits(ctx2);
+        auto n_vocab = llama_n_vocab(ctx2);
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+        auto next_token = llama_sample_token(ctx2, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx2, next_token);
         last_n_tokens_data.push_back(next_token);
         printf("%s", next_token_str);
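
Note the behavioral change here: the old call sampled with top_k = 40 and repeat_penalty = 1.1, while the new code draws from the raw, unpenalized distribution via llama_sample_token. The example should still work as a state-save test because the second run restores the saved state before sampling, so identical logits plus identical sampler behavior should reproduce the same tokens. The candidates-building block is now duplicated across the two runs; a helper could factor it out (a sketch, not part of the commit):

#include <vector>
#include "llama.h"

// Draw the next token from the raw distribution (no penalties, temperature 1),
// as both runs above now do. Factored out for reuse; not part of the commit.
static llama_token sample_raw(llama_context * ctx) {
    float * logits = llama_get_logits(ctx);
    const int n_vocab = llama_n_vocab(ctx);
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    return llama_sample_token(ctx, &candidates_p);
}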
