Skip to content

Commit 9847a37

Browse files
committed
params : allow penalty_last_n == -1 to be equal to context size
ggml-ci
1 parent a04a5b5 commit 9847a37

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

common/arg.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
909909
{"--repeat-last-n"}, "N",
910910
string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n),
911911
[](common_params & params, int value) {
912+
if (value < -1) {
913+
throw std::runtime_error(string_format("error: invalid repeat-last-n = %d\n", value));
914+
}
912915
params.sampling.penalty_last_n = value;
913916
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
914917
}
@@ -963,6 +966,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
963966
{"--dry-penalty-last-n"}, "N",
964967
string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n),
965968
[](common_params & params, int value) {
969+
if (value < -1) {
970+
throw std::runtime_error(string_format("error: invalid dry-penalty-last-n = %d\n", value));
971+
}
966972
params.sampling.dry_penalty_last_n = value;
967973
}
968974
).set_sparam());

common/common.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,16 @@ struct common_init_result common_init_from_params(common_params & params) {
945945
params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
946946
}
947947

948+
if (params.sampling.penalty_last_n == -1) {
949+
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
950+
params.sampling.penalty_last_n = llama_n_ctx(lctx);
951+
}
952+
953+
if (params.sampling.dry_penalty_last_n == -1) {
954+
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
955+
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
956+
}
957+
948958
if (params.warmup) {
949959
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
950960

examples/server/server.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ struct server_task {
183183

184184
static slot_params params_from_json_cmpl(
185185
const llama_model * model,
186+
const llama_context * ctx,
186187
const common_params & params_base,
187188
const json & data) {
188189
slot_params params;
@@ -237,8 +238,27 @@ struct server_task {
237238
params.speculative.n_min = std::max(params.speculative.n_min, 2);
238239
params.speculative.n_max = std::max(params.speculative.n_max, 0);
239240

241+
// TODO: add more sanity checks for the input parameters
242+
243+
if (params.sampling.penalty_last_n < -1) {
244+
throw std::runtime_error("Error: repeat_last_n must be >= -1");
245+
}
246+
247+
if (params.sampling.dry_penalty_last_n < -1) {
248+
throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
249+
}
250+
251+
if (params.sampling.penalty_last_n == -1) {
252+
// note: should be the slot's context and not the full context, but it's ok
253+
params.sampling.penalty_last_n = llama_n_ctx(ctx);
254+
}
255+
256+
if (params.sampling.dry_penalty_last_n == -1) {
257+
params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
258+
}
259+
240260
if (params.sampling.dry_base < 1.0f) {
241-
params.sampling.dry_base = defaults.sampling.dry_base;
261+
params.sampling.dry_base = defaults.sampling.dry_base;
242262
}
243263

244264
// sequence breakers for DRY
@@ -3379,7 +3399,7 @@ int main(int argc, char ** argv) {
33793399
task.index = i;
33803400

33813401
task.prompt_tokens = std::move(tokenized_prompts[i]);
3382-
task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.params_base, data);
3402+
task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
33833403
task.id_selected_slot = json_value(data, "id_slot", -1);
33843404

33853405
// OAI-compat

0 commit comments

Comments
 (0)