|
6 | 6 | #include <string>
|
7 | 7 | #include <iterator>
|
8 | 8 | #include <algorithm>
|
| 9 | +#include <sstream> |
| 10 | +#include <iostream> |
9 | 11 |
|
10 | 12 | #if defined (_WIN32)
|
11 | 13 | #include <fcntl.h>
|
@@ -132,18 +134,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
132 | 134 | break;
|
133 | 135 | }
|
134 | 136 | params.repeat_penalty = std::stof(argv[i]);
|
135 |
| - } else if (arg == "--alpha_frequency") { |
| 137 | + } else if (arg == "--frequency_penalty") { |
136 | 138 | if (++i >= argc) {
|
137 | 139 | invalid_param = true;
|
138 | 140 | break;
|
139 | 141 | }
|
140 |
| - params.alpha_frequency = std::stof(argv[i]); |
141 |
| - } else if (arg == "--alpha_presence") { |
| 142 | + params.frequency_penalty = std::stof(argv[i]); |
| 143 | + } else if (arg == "--presence_penalty") { |
142 | 144 | if (++i >= argc) {
|
143 | 145 | invalid_param = true;
|
144 | 146 | break;
|
145 | 147 | }
|
146 |
| - params.alpha_presence = std::stof(argv[i]); |
| 148 | + params.presence_penalty = std::stof(argv[i]); |
147 | 149 | } else if (arg == "--mirostat") {
|
148 | 150 | if (++i >= argc) {
|
149 | 151 | invalid_param = true;
|
@@ -223,7 +225,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
223 | 225 | } else if (arg == "--perplexity") {
|
224 | 226 | params.perplexity = true;
|
225 | 227 | } else if (arg == "--ignore-eos") {
|
226 |
| - params.ignore_eos = true; |
| 228 | + params.logit_bias[llama_token_eos()] = -INFINITY; |
| 229 | + } else if (arg == "--no-penalize-nl") { |
| 230 | + params.penalize_nl = false; |
| 231 | + } else if (arg == "-l" || arg == "--logit-bias") { |
| 232 | + if (++i >= argc) { |
| 233 | + invalid_param = true; |
| 234 | + break; |
| 235 | + } |
| 236 | + std::stringstream ss(argv[i]); |
| 237 | + llama_token key; |
| 238 | + char sign; |
| 239 | + std::string value_str; |
| 240 | + try { |
| 241 | + if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-' || sign == '=' || sign == ':')) { |
| 242 | + params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); |
| 243 | + } else { |
| 244 | + throw std::exception(); |
| 245 | + } |
| 246 | + } catch (const std::exception &e) { |
| 247 | + invalid_param = true; |
| 248 | + break; |
| 249 | + } |
227 | 250 | } else if (arg == "--n_parts") {
|
228 | 251 | if (++i >= argc) {
|
229 | 252 | invalid_param = true;
|
@@ -277,19 +300,23 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
277 | 300 | fprintf(stderr, " -f FNAME, --file FNAME\n");
|
278 | 301 | fprintf(stderr, " prompt file to start generation.\n");
|
279 | 302 | fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
|
280 |
| - fprintf(stderr, " --top_k N top-k sampling (default: %d, disabled: 0)\n", params.top_k); |
281 |
| - fprintf(stderr, " --top_p N top-p sampling (default: %.1f, disabled: 1.0)\n", (double)params.top_p); |
282 |
| - fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, disabled: 1.0)\n", (double)params.tfs_z); |
283 |
| - fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, disabled: 1.0)\n", (double)params.typical_p); |
284 |
| - fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, disabled: 0)\n", params.repeat_last_n); |
285 |
| - fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, disabled: 1.0)\n", (double)params.repeat_penalty); |
286 |
| - fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %.1f, disabled: 0.0)\n", (double)params.alpha_presence); |
287 |
| - fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f, disabled: 0.0)\n", (double)params.alpha_frequency); |
288 |
| - fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, disabled: 0, mirostat: 1, mirostat 2.0: 2)\n", params.mirostat); |
| 303 | + fprintf(stderr, " --top_k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k); |
| 304 | + fprintf(stderr, " --top_p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); |
| 305 | + fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z); |
| 306 | + fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p); |
| 307 | + fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n); |
| 308 | + fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty); |
| 309 | + fprintf(stderr, " --presence_penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty); |
| 310 | + fprintf(stderr, " --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty); |
| 311 | + fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, 0 = disabled, 1 = mirostat, 2 = mirostat 2.0)\n", params.mirostat); |
289 | 312 | fprintf(stderr, " --mirostat_eta N mirostat learning rate (default: %.1f)\n", (double)params.mirostat_eta);
|
290 | 313 | fprintf(stderr, " --mirostat_tau N mirostat target entropy (default: %.1f)\n", (double)params.mirostat_tau);
|
| 314 | + fprintf(stderr, " -l TOKEN+BIAS, --logit-bias TOKEN+BIAS"); |
| 315 | + fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n"); |
| 316 | + fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello'\n"); |
291 | 317 | fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
|
292 |
| - fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n"); |
| 318 | + fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2+-inf)\n"); |
| 319 | + fprintf(stderr, " --no-penalize-nl do not penalize newline token\n"); |
293 | 320 | fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");
|
294 | 321 | fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
|
295 | 322 | fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
|
|
0 commit comments