diff --git a/common/sampling.cpp b/common/sampling.cpp
index f2466550168a7..ad6dba83da48b 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -267,13 +267,18 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
 
+    // repetition penalties
     const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
     const float   penalty_repeat  = params.penalty_repeat;
     const float   penalty_freq    = params.penalty_freq;
     const float   penalty_present = params.penalty_present;
-
     const bool    penalize_nl     = params.penalize_nl;
 
+    // DRY sampler parameters
+    const float dry_multiplier     = params.dry_multiplier;
+    const float dry_base           = params.dry_base;
+    const int   dry_allowed_length = params.dry_allowed_length;
+
     auto & prev = ctx_sampling->prev;
     auto & cur  = ctx_sampling->cur;
 
@@ -309,10 +314,19 @@ static llama_token_data_array llama_sampling_prepare_impl(
         if (penalty_tokens_used_size) {
             const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
 
+            // repetition penalties
             llama_sample_repetition_penalties(ctx_main, &cur_p,
                     penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                     penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
 
+            // DRY penalties (multiplier > 0 means enabled)
+            if (dry_multiplier > 0.0f) {
+                llama_sample_dry(ctx_main, &cur_p,
+                        penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                        penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
+                        params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
+            }
+
             if (!penalize_nl) {
                 for (size_t idx = 0; idx < cur_p.size; idx++) {
                     if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
diff --git a/common/sampling.h b/common/sampling.h
index cf7081e3674f1..bfc338ef70f1e 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -41,6 +41,9 @@ typedef struct llama_sampling_params {
     float       mirostat_eta     = 0.10f; // learning rate
     bool        penalize_nl      = false; // consider newlines as a repeatable token
     uint32_t    seed             = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
+    float       dry_multiplier   = 0.0f;  // 0.0f = disabled, recommended value: 0.8f
+    float       dry_base         = 1.75f;
+    int         dry_allowed_length = 2;
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -61,6 +64,7 @@ typedef struct llama_sampling_params {
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
 
     std::vector<llama_token> penalty_prompt_tokens;
+    std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
     bool                     use_penalty_prompt_tokens = false;
 } llama_sampling_params;
 
diff --git a/llama.cpp b/llama.cpp
index 3a84b4916bd30..bb5aff46f800d 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13233,6 +13233,90 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
     }
 }
 
+void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
+    // sanity check
+    GGML_ASSERT(last_tokens_size > 0);
+
+    // get the last token
+    auto last_token = last_tokens[last_tokens_size - 1];
+
+    // if the last token is one of the sequence breakers, skip the whole sampler
+    if (std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) {
+        return;
+    }
+
+    // map of "next token" -> maximum match length seen for that token
+    std::unordered_map<llama_token, size_t> match_lengths;
+
+    // loop through each previous token (excluding the last token)
+    for (int i = 0; i < last_tokens_size - 1; ++i) {
+        // skip the compare token if it's not the same as the last token
+        if (last_tokens[i] != last_token) {
+            continue;
+        }
+
+        // get the next token (i + 1 is always less than last_tokens_size)
+        auto next_token = last_tokens[i + 1];
+
+        // try to extend the match backwards (match length starts at 1 because the last token is already matched)
+        size_t match_length = 1;
+
+        // loop through the previous tokens
+        for (;; match_length++) {
+            // if we have reached the start of our last tokens, break
+            if ((size_t) i < match_length) break;
+
+            // compare token starts at our prev index, going backwards by match length
+            auto compare_token = last_tokens[i - match_length];
+
+            // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
+            auto head_token = last_tokens[last_tokens_size - 1 - match_length];
+
+            // if the compare token is one of the sequence breakers, break out of the match
+            if (std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
+                break;
+
+            // break out of the match as soon as the tokens don't match
+            if (compare_token != head_token)
+                break;
+        }
+
+        // check if the next token already exists in the map
+        auto it = match_lengths.find(next_token);
+
+        if (it == match_lengths.end()) {
+            // key does not exist, insert the new value
+            match_lengths[next_token] = match_length;
+        } else {
+            // key exists, keep the maximum of the new and the existing value
+            it->second = std::max(it->second, match_length);
+        }
+    }
+
+    // apply penalties
+    for (const auto & pair : match_lengths) {
+        auto next_token   = pair.first;
+        auto match_length = pair.second;
+
+        // if the match length is greater than the configured allowed length, apply a penalty
+        if (match_length > (size_t) dry_allowed_length) {
+
+            // find the next token in candidates->data
+            size_t i = 0;
+            for (; i < candidates->size; ++i) {
+                if (candidates->data[i].id == next_token) {
+                    // calculate the penalty: multiplier * base^(match_length - allowed_length)
+                    float penalty = dry_multiplier * powf(dry_base, match_length - dry_allowed_length);
+
+                    // apply the DRY penalty
+                    candidates->data[i].logit -= penalty;
+                    break;
+                }
+            }
+        }
+    }
+}
+
 void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
     if (z >= 1.0f || candidates->size <= 2) {
         return;
diff --git a/llama.h b/llama.h
index 0eb2a1e9ab0a2..0c6b86c16323c 100644
--- a/llama.h
+++ b/llama.h
@@ -924,6 +924,18 @@ extern "C" {
                            float   p,
                           size_t   min_keep);
 
+    /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+    LLAMA_API void llama_sample_dry(
+              struct llama_context * ctx,
+            llama_token_data_array * candidates,
+                 const llama_token * last_tokens,
+                               int   last_tokens_size,
+                             float   dry_base,
+                             float   dry_multiplier,
+                               int   dry_allowed_length,
+                 const llama_token * seq_breakers,
+                               int   seq_breakers_size);
+
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(
             struct llama_context * ctx,
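For anyone who wants to poke at the matching logic without building the whole tree, below is a small standalone sketch that re-implements the same backwards match-length scan and penalty formula on a toy token history. It is not part of this patch: the integer token IDs, the breaker value, and the parameter values are all made up for illustration.

// Toy re-implementation of the DRY match-length scan (illustrative only).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <unordered_map>
#include <vector>

using token = int;

int main() {
    // history ends in a repeat of {5, 6, 7}; the last token is 7
    const std::vector<token> last_tokens = {1, 5, 6, 7, 8, 2, 5, 6, 7};
    const std::vector<token> breakers    = {0}; // stand-in for e.g. a newline token

    const float  dry_base        = 1.75f;
    const float  dry_multiplier  = 0.8f;
    const size_t dry_allowed_len = 2;

    const token last = last_tokens.back();
    std::unordered_map<token, size_t> match_lengths;

    for (size_t i = 0; i + 1 < last_tokens.size(); ++i) {
        if (last_tokens[i] != last) continue;
        const token next = last_tokens[i + 1];

        size_t len = 1; // the last token itself already matches at position i
        for (;; ++len) {
            if (i < len) break;
            const token cmp  = last_tokens[i - len];
            const token head = last_tokens[last_tokens.size() - 1 - len];
            if (std::find(breakers.begin(), breakers.end(), cmp) != breakers.end()) break;
            if (cmp != head) break;
        }

        auto it = match_lengths.find(next);
        match_lengths[next] = (it == match_lengths.end()) ? len : std::max(it->second, len);
    }

    for (const auto & kv : match_lengths) {
        if (kv.second > dry_allowed_len) {
            const float penalty = dry_multiplier * std::pow(dry_base, (float) (kv.second - dry_allowed_len));
            std::printf("token %d: match length %zu, logit penalty %.2f\n", kv.first, kv.second, penalty);
        }
    }
    return 0;
}

Running it prints a penalty of 1.40 for token 8, which previously followed the repeated {5, 6, 7} suffix: match length 3 gives 0.8 * 1.75^(3 - 2) = 1.4, matching the formula applied in llama_sample_dry.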