From 9105cf435bb0badbf60f7ea8b3437151e2404f9c Mon Sep 17 00:00:00 2001 From: wwoodsTM Date: Mon, 5 Aug 2024 00:03:38 -0600 Subject: [PATCH 1/2] Add DRY sampling parameters to gpt_params and server_context --- common/common.cpp | 25 ++++++++++++ examples/server/server.cpp | 78 ++++++++++++++++++++++++++------------ pr-6839.diff | 0 3 files changed, 79 insertions(+), 24 deletions(-) create mode 100644 pr-6839.diff diff --git a/common/common.cpp b/common/common.cpp index 60c7eac75c613..4cf094b7516cb 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -555,6 +555,26 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa sparams.penalty_present = std::stof(argv[i]); return true; } + if (arg == "--dry-multiplier") { + CHECK_ARG + sparams.dry_multiplier = std::stof(argv[i]); + return true; + } + if (arg == "--dry-base") { + CHECK_ARG + sparams.dry_base = std::stof(argv[i]); + return true; + } + if (arg == "--dry-allowed-length") { + CHECK_ARG + sparams.dry_allowed_length = std::stoi(argv[i]); + return true; + } + if (arg == "--dry-penalty-last-n") { + CHECK_ARG + sparams.dry_penalty_last_n = std::stoi(argv[i]); + return true; + } if (arg == "--dynatemp-range") { CHECK_ARG sparams.dynatemp_range = std::stof(argv[i]); @@ -1471,6 +1491,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat }); options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present }); options.push_back({ "*", " --frequency-penalty N", "repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_freq }); + options.push_back({ "*", " --dry-multiplier N", "DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)sparams.dry_multiplier }); + options.push_back({ "*", " --dry-base N", "DRY sampling base (default: %.1f)", (double)sparams.dry_base }); + options.push_back({ "*", " --dry-allowed-length N", "DRY sampling allowed length (default: %d)", sparams.dry_allowed_length }); + options.push_back({ "*", " --dry-penalty-last-n N", "DRY sampling penalty last n tokens (-1 = context size, default: %d)", sparams.dry_penalty_last_n }); + options.push_back({ "*", " --dynatemp-range N", "dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)sparams.dynatemp_range }); options.push_back({ "*", " --dynatemp-exp N", "dynamic temperature exponent (default: %.1f)", (double)sparams.dynatemp_exponent }); options.push_back({ "*", " --mirostat N", "use Mirostat sampling.\n" diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 7813a2957d6bc..4b2654db9b767 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -898,30 +898,55 @@ struct server_context { slot.oaicompat_model = ""; } - slot.params.stream = json_value(data, "stream", false); - slot.params.cache_prompt = json_value(data, "cache_prompt", false); - slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict); - slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); - slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); - slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); - slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); - slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); - slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); - slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); - slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); - slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); - slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); - slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); - slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); - slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); - slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); - slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); - slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); - slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); - slot.sparams.seed = json_value(data, "seed", default_sparams.seed); - slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); - slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); + slot.params.stream = json_value(data, "stream", false); + slot.params.cache_prompt = json_value(data, "cache_prompt", false); + slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict); + slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); + slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); + slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); + slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); + slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); + slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); + slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); + slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); + slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); + slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); + slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); + slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); + slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier); + slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base); + slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length); + slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n); + slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); + slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); + slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); + slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); + slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); + slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); + slot.sparams.seed = json_value(data, "seed", default_sparams.seed); + slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); + slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); + + // sequence breakers for DRY + { + auto dry_seq_breakers = data.find("dry_seq_breakers"); + if (dry_seq_breakers != data.end()) { + try { + if (dry_seq_breakers->is_array()) { + slot.sparams.dry_seq_breakers = dry_seq_breakers->get>(); + } else if (dry_seq_breakers->is_string()) { + slot.sparams.dry_seq_breakers = json::parse(dry_seq_breakers->get()).get>(); + } else { + send_error(task, "\"dry_seq_breakers\": Expected an array of strings or a JSON-encoded array of strings.", ERROR_TYPE_INVALID_REQUEST); + return false; + } + } catch (const std::exception & e) { + send_error(task, std::string("\"dry_seq_breakers\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + } + // process "json_schema" and "grammar" if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { @@ -1339,6 +1364,11 @@ struct server_context { {"frequency_penalty", slot.sparams.penalty_freq}, {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, + {"dry_multiplier", slot.sparams.dry_multiplier}, + {"dry_base", slot.sparams.dry_base}, + {"dry_allowed_length", slot.sparams.dry_allowed_length}, + {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n}, + {"dry_seq_breakers", slot.sparams.dry_seq_breakers}, {"mirostat", slot.sparams.mirostat}, {"mirostat_tau", slot.sparams.mirostat_tau}, {"mirostat_eta", slot.sparams.mirostat_eta}, diff --git a/pr-6839.diff b/pr-6839.diff new file mode 100644 index 0000000000000..e69de29bb2d1d From 20dc562f45434b105b3d167830e927c054600e41 Mon Sep 17 00:00:00 2001 From: wwoodsTM <104587230+wwoodsTM@users.noreply.github.com> Date: Mon, 5 Aug 2024 00:41:26 -0600 Subject: [PATCH 2/2] Delete pr-6839.diff --- pr-6839.diff | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 pr-6839.diff diff --git a/pr-6839.diff b/pr-6839.diff deleted file mode 100644 index e69de29bb2d1d..0000000000000