Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
// But it does not penalize tokens that logits are near zero, which is a problem.
-// Another solution is to convert the logits to probabilities, apply the penalty, and then convert back to logits.
-// float probability = std::exp(candidates[i].logit);
-// probability /= penalty;
-// candidates[i].logit = std::log(probability);
 }

 candidates->sorted = false;
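The comments removed in this hunk describe a probability-space alternative to penalizing logits directly. The sketch below contrasts the two ideas. It assumes the surrounding code uses the common sign-aware scaling (divide positive logits, multiply non-positive ones), which this hunk does not show. Both helper names are hypothetical and not part of llama.cpp.

```cpp
#include <cassert>
#include <cmath>

// Assumed sign-aware scaling: divide positive logits and multiply
// non-positive ones, so a penalty > 1 never makes a token more likely.
// A logit close to zero is left almost unchanged, which is the weakness
// the first comment in the hunk points out.
inline float penalize_logit_scaled(float logit, float penalty) {
    assert(penalty > 0.0f);
    return logit > 0.0f ? logit / penalty : logit * penalty;
}

// Probability-space alternative from the removed comments:
// log(exp(logit) / penalty) == logit - log(penalty).
// This is a constant shift, so logits near zero are penalized as well.
inline float penalize_logit_shifted(float logit, float penalty) {
    assert(penalty > 0.0f);
    return logit - std::log(penalty);
}
```

With `penalty = 2`, a logit of exactly zero stays at zero under scaling but drops by `log(2)` under the shifted variant.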
@@ -1757,9 +1751,9 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
 }


-llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu) {
+llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
-/// @param ctx The llama context.
 /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
 /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
 /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
 /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
-/// @param N The size of the vocabulary. This is used in the calculation of the `k` value.
-/// @param k A reference to the integer variable used to store the calculated top-k value. The top-k value determines how many of the most probable tokens are considered for sampling.
-/// @param mu A reference to the floating-point variable that represents the maximum cross-entropy value. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu);
+/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
+LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
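The documentation above says that `m` feeds the estimate of `s_hat`, which in turn yields `k`, and the removed `N` parameter was the vocabulary size used in that calculation. A rough sketch of how these quantities could fit together follows. It assumes the Zipf-exponent estimator and the `k` formula from the Mirostat paper. The helper names `estimate_s_hat` and `estimate_top_k` are hypothetical and are not part of the llama API.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Estimate the Zipf exponent s_hat from the m most probable tokens.
// Assumes probs is sorted in descending order and strictly positive.
inline float estimate_s_hat(const std::vector<float> & probs, int m) {
    assert(m >= 2 && probs.size() >= 2);
    float sum_tb = 0.0f;
    float sum_tt = 0.0f;
    for (std::size_t i = 0; i + 1 < probs.size() && i + 1 < (std::size_t) m; ++i) {
        const float t = std::log(float(i + 2) / float(i + 1)); // log rank ratio
        const float b = std::log(probs[i] / probs[i + 1]);     // log probability ratio
        sum_tb += t * b;
        sum_tt += t * t;
    }
    return sum_tb / sum_tt;
}

// Turn s_hat, the current mu, and the vocabulary size into a top-k cutoff.
// Assumes s_hat != 1, otherwise the expression is 0/0.
inline float estimate_top_k(float s_hat, float mu, float n_vocab) {
    const float eps = s_hat - 1.0f;
    return std::pow((eps * std::pow(2.0f, mu)) / (1.0f - std::pow(n_vocab, -eps)),
                    1.0f / s_hat);
}
```

For probabilities that follow an exact power law with exponent `s`, the estimator recovers `s`, which makes it easy to sanity-check. A real sampler would presumably round `k` and clamp it to the number of available candidates.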
+
+/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
+/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
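The `tau`, `eta`, and `mu` parameters described above form a feedback loop: candidates whose surprise exceeds `mu` are dropped, a token is sampled, and `mu` is nudged toward the target. A minimal sketch of the truncation and update steps follows, assuming surprise is measured in bits and probabilities are sorted in descending order. All helper names are hypothetical; this is not the implementation behind `llama_sample_token_mirostat`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Surprise (in bits) of a token with probability p.
inline float surprise_bits(float p) {
    assert(p > 0.0f);
    return -std::log2(p);
}

// Number of leading candidates to keep: those whose surprise does not
// exceed mu. At least one candidate is always kept so sampling stays possible.
inline std::size_t mirostat_v2_keep(const std::vector<float> & probs, float mu) {
    assert(!probs.empty());
    std::size_t keep = 0;
    while (keep < probs.size() && surprise_bits(probs[keep]) <= mu) {
        ++keep;
    }
    return keep == 0 ? 1 : keep;
}

// Feedback step: move mu against the error between the observed surprise
// of the sampled token and the target tau, scaled by the learning rate eta.
inline float mirostat_update_mu(float mu, float eta, float tau, float observed_surprise) {
    return mu - eta * (observed_surprise - tau);
}
```

A caller would start with `mu = 2 * tau`, as the parameter documentation states, and carry the updated `mu` from one sampled token to the next. Sampling from the kept candidates and renormalizing their probabilities is left out of the sketch.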