Skip to content

Commit 6c4c88d

Browse files
Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
1 parent 61f822f commit 6c4c88d

File tree

6 files changed

+70
-99
lines changed

6 files changed

+70
-99
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
7676
# Compile flags
7777
#
7878

79-
set(CMAKE_CXX_STANDARD 20)
79+
set(CMAKE_CXX_STANDARD 11)
8080
set(CMAKE_CXX_STANDARD_REQUIRED true)
8181
set(CMAKE_C_STANDARD 11)
8282
set(CMAKE_C_STANDARD_REQUIRED true)

examples/common.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,13 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
158158
break;
159159
}
160160
params.mirostat = std::stoi(argv[i]);
161-
} else if (arg == "--mirostat_eta") {
161+
} else if (arg == "--mirostat_lr") {
162162
if (++i >= argc) {
163163
invalid_param = true;
164164
break;
165165
}
166166
params.mirostat_eta = std::stof(argv[i]);
167-
} else if (arg == "--mirostat_tau") {
167+
} else if (arg == "--mirostat_ent") {
168168
if (++i >= argc) {
169169
invalid_param = true;
170170
break;
@@ -242,7 +242,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
242242
char sign;
243243
std::string value_str;
244244
try {
245-
if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-' || sign == '=' || sign == ':')) {
245+
if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
246246
params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
247247
} else {
248248
throw std::exception();
@@ -309,18 +309,21 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
309309
fprintf(stderr, " --top_p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
310310
fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
311311
fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
312-
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
312+
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
313313
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
314314
fprintf(stderr, " --presence_penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
315315
fprintf(stderr, " --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
316-
fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, 0 = disabled, 1 = mirostat, 2 = mirostat 2.0)\n", params.mirostat);
317-
fprintf(stderr, " --mirostat_eta N mirostat learning rate (default: %.1f)\n", (double)params.mirostat_eta);
318-
fprintf(stderr, " --mirostat_tau N mirostat target entropy (default: %.1f)\n", (double)params.mirostat_tau);
319-
fprintf(stderr, " -l TOKEN+BIAS, --logit-bias TOKEN+BIAS");
316+
fprintf(stderr, " --mirostat N use Mirostat sampling.\n");
317+
fprintf(stderr, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
318+
fprintf(stderr, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
319+
fprintf(stderr, " --mirostat_lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
320+
fprintf(stderr, " --mirostat_ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
321+
fprintf(stderr, " -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
320322
fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n");
321-
fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello'\n");
323+
fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
324+
fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
322325
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
323-
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2+-inf)\n");
326+
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
324327
fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
325328
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");
326329
fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);

examples/main/main.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ int main(int argc, char ** argv) {
276276
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
277277
}
278278
}
279-
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f\n",
279+
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
280280
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
281281
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
282282
fprintf(stderr, "\n\n");
@@ -420,8 +420,8 @@ int main(int argc, char ** argv) {
420420

421421
std::vector<llama_token_data> candidates;
422422
candidates.reserve(n_vocab);
423-
for (size_t i = 0; i < (size_t) n_vocab; i++) {
424-
candidates.emplace_back(i, logits[i], 0.0f);
423+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
424+
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
425425
}
426426

427427
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
@@ -445,11 +445,12 @@ int main(int argc, char ** argv) {
445445
} else {
446446
if (mirostat == 1) {
447447
static float mirostat_mu = 2.0f * mirostat_tau;
448-
static int mirostat_k = 40;
449448
const int mirostat_m = 100;
450-
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, float(n_vocab), &mirostat_k, &mirostat_mu);
449+
llama_sample_temperature(ctx, &candidates_p, temp);
450+
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
451451
} else if (mirostat == 2) {
452452
static float mirostat_mu = 2.0f * mirostat_tau;
453+
llama_sample_temperature(ctx, &candidates_p, temp);
453454
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
454455
} else {
455456
// Temperature sampling

llama.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,12 +1710,6 @@ void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_dat
17101710
} else {
17111711
candidates->data[i].logit /= penalty;
17121712
}
1713-
1714-
// But it does not penalize tokens that logits are near zero, which is a problem.
1715-
// Another solution is to convert the logits to probabilities, apply the penalty, and then convert back to logits.
1716-
// float probability = std::exp(candidates[i].logit);
1717-
// probability /= penalty;
1718-
// candidates[i].logit = std::log(probability);
17191713
}
17201714

17211715
candidates->sorted = false;
@@ -1757,9 +1751,9 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
17571751
}
17581752

17591753

1760-
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu) {
1754+
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
17611755
assert(ctx);
1762-
1756+
auto N = float(llama_n_vocab(ctx));
17631757
int64_t t_start_sample_us;
17641758
t_start_sample_us = ggml_time_us();
17651759

@@ -1779,12 +1773,10 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_
17791773

17801774
// Compute k from the estimated s_hat and target surprise value
17811775
float epsilon_hat = s_hat - 1;
1782-
float new_k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
1783-
*k = int(std::min(new_k, float(candidates->size)));
1776+
float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
17841777

17851778
// Sample the next word X using top-k sampling
1786-
// printf("llama_sample_mirostat *k = %d\n", *k);
1787-
llama_sample_top_k(nullptr, candidates, *k);
1779+
llama_sample_top_k(nullptr, candidates, int(k));
17881780
if (ctx) {
17891781
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
17901782
}

llama.h

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -189,37 +189,47 @@ extern "C" {
189189

190190
// Sampling functions
191191

192-
/// @brief Repetition penalty
193-
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/pdf/1909.05858.pdf with negative logit fix
192+
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
194193
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty);
195-
/// @brief Frequency and presence repetition penalties
196-
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details
194+
195+
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
197196
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
198197

198+
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
199199
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
200+
201+
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
200202
LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
203+
204+
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
201205
LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
202206

203-
/// @brief Tail Free Sampling https://www.trentonbricken.com/Tail-Free-Sampling/
207+
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
204208
LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
205209

206-
/// @brief Locally Typical Sampling https://arxiv.org/pdf/2202.00666.pdf
210+
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
207211
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
208212
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
209213

210-
/// @brief Mirostat implementation.
211214
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
212-
/// @param ctx The llama context.
213215
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
214216
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
215217
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
216218
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
217-
/// @param N The size of the vocabulary. This is used in the calculation of the `k` value.
218-
/// @param k A reference to the integer variable used to store the calculated top-k value. The top-k value determines how many of the most probable tokens are considered for sampling.
219-
/// @param mu A reference to the floating-point variable that represents the maximum cross-entropy value. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
220-
LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu);
219+
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
220+
LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
221+
222+
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
223+
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
224+
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
225+
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
226+
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
221227
LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
228+
229+
/// @details Selects the token with the highest probability.
222230
LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
231+
232+
/// @details Randomly selects a token from the candidates based on their probabilities.
223233
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
224234

225235
// Performance information

0 commit comments

Comments (0)