Skip to content

Commit d1676a1

Browse files
authored
Merge pull request #29 from wwoodsTM/test-dry-sampler
Add DRY sampling parameters to gpt_params and server_context
2 parents e862def + 20dc562 commit d1676a1

File tree

2 files changed

+79
-24
lines changed

2 files changed

+79
-24
lines changed

common/common.cpp

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -555,6 +555,26 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
555555
sparams.penalty_present = std::stof(argv[i]);
556556
return true;
557557
}
558+
if (arg == "--dry-multiplier") {
559+
CHECK_ARG
560+
sparams.dry_multiplier = std::stof(argv[i]);
561+
return true;
562+
}
563+
if (arg == "--dry-base") {
564+
CHECK_ARG
565+
sparams.dry_base = std::stof(argv[i]);
566+
return true;
567+
}
568+
if (arg == "--dry-allowed-length") {
569+
CHECK_ARG
570+
sparams.dry_allowed_length = std::stoi(argv[i]);
571+
return true;
572+
}
573+
if (arg == "--dry-penalty-last-n") {
574+
CHECK_ARG
575+
sparams.dry_penalty_last_n = std::stoi(argv[i]);
576+
return true;
577+
}
558578
if (arg == "--dynatemp-range") {
559579
CHECK_ARG
560580
sparams.dynatemp_range = std::stof(argv[i]);
@@ -1471,6 +1491,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
14711491
options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
14721492
options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
14731493
options.push_back({ "*", " --frequency-penalty N", "repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_freq });
1494+
options.push_back({ "*", " --dry-multiplier N", "DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)sparams.dry_multiplier });
1495+
options.push_back({ "*", " --dry-base N", "DRY sampling base (default: %.1f)", (double)sparams.dry_base });
1496+
options.push_back({ "*", " --dry-allowed-length N", "DRY sampling allowed length (default: %d)", sparams.dry_allowed_length });
1497+
options.push_back({ "*", " --dry-penalty-last-n N", "DRY sampling penalty last n tokens (-1 = context size, default: %d)", sparams.dry_penalty_last_n });
1498+
14741499
options.push_back({ "*", " --dynatemp-range N", "dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)sparams.dynatemp_range });
14751500
options.push_back({ "*", " --dynatemp-exp N", "dynamic temperature exponent (default: %.1f)", (double)sparams.dynatemp_exponent });
14761501
options.push_back({ "*", " --mirostat N", "use Mirostat sampling.\n"

examples/server/server.cpp

Lines changed: 54 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -898,30 +898,55 @@ struct server_context {
898898
slot.oaicompat_model = "";
899899
}
900900

901-
slot.params.stream = json_value(data, "stream", false);
902-
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
903-
slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict);
904-
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
905-
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
906-
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
907-
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
908-
slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
909-
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
910-
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
911-
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
912-
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
913-
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
914-
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
915-
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
916-
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
917-
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
918-
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
919-
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
920-
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
921-
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
922-
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
923-
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
924-
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
901+
slot.params.stream = json_value(data, "stream", false);
902+
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
903+
slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict);
904+
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
905+
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
906+
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
907+
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
908+
slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
909+
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
910+
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
911+
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
912+
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
913+
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
914+
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
915+
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
916+
slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
917+
slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
918+
slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
919+
slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
920+
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
921+
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
922+
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
923+
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
924+
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
925+
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
926+
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
927+
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
928+
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
929+
930+
// sequence breakers for DRY
931+
{
932+
auto dry_seq_breakers = data.find("dry_seq_breakers");
933+
if (dry_seq_breakers != data.end()) {
934+
try {
935+
if (dry_seq_breakers->is_array()) {
936+
slot.sparams.dry_seq_breakers = dry_seq_breakers->get<std::vector<std::string>>();
937+
} else if (dry_seq_breakers->is_string()) {
938+
slot.sparams.dry_seq_breakers = json::parse(dry_seq_breakers->get<std::string>()).get<std::vector<std::string>>();
939+
} else {
940+
send_error(task, "\"dry_seq_breakers\": Expected an array of strings or a JSON-encoded array of strings.", ERROR_TYPE_INVALID_REQUEST);
941+
return false;
942+
}
943+
} catch (const std::exception & e) {
944+
send_error(task, std::string("\"dry_seq_breakers\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
945+
return false;
946+
}
947+
}
948+
}
949+
925950

926951
// process "json_schema" and "grammar"
927952
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
@@ -1339,6 +1364,11 @@ struct server_context {
13391364
{"frequency_penalty", slot.sparams.penalty_freq},
13401365
{"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
13411366
{"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
1367+
{"dry_multiplier", slot.sparams.dry_multiplier},
1368+
{"dry_base", slot.sparams.dry_base},
1369+
{"dry_allowed_length", slot.sparams.dry_allowed_length},
1370+
{"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
1371+
{"dry_seq_breakers", slot.sparams.dry_seq_breakers},
13421372
{"mirostat", slot.sparams.mirostat},
13431373
{"mirostat_tau", slot.sparams.mirostat_tau},
13441374
{"mirostat_eta", slot.sparams.mirostat_eta},

0 commit comments

Comments
 (0)