Skip to content

Commit 4d603e3

Browse files
committed
added DRY implementation
1 parent aea4ad0 commit 4d603e3

File tree

1 file changed

+84
-0
lines changed

1 file changed

+84
-0
lines changed

llama.cpp

Lines changed: 84 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13233,6 +13233,90 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
1323313233
}
1323413234
}
1323513235

13236+
void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
13237+
// sanity check
13238+
GGML_ASSERT(last_tokens_size > 0);
13239+
13240+
// get the last token
13241+
auto last_token = last_tokens[last_tokens_size - 1];
13242+
13243+
// if last token is part of the sequence breakers, skip whole sampler
13244+
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) {
13245+
return;
13246+
}
13247+
13248+
// create an unordered map of "next tokens" <-> max match length
13249+
std::unordered_map<llama_token, size_t> match_lengths;
13250+
13251+
// loop through each previous token (exclude the last token)
13252+
for (size_t i = 0; i < last_tokens_size - 1; ++i) {
13253+
// skip if the compare token if it's not the same as the last token
13254+
if(last_tokens[i] != last_token) {
13255+
continue;
13256+
}
13257+
13258+
// get the next token (i + 1 is always less than last_tokens_size)
13259+
auto next_token = last_tokens[i + 1];
13260+
13261+
// try to extend the match backwards (match length starts a 1 because last token is already matched)
13262+
size_t match_length = 1;
13263+
13264+
// loop through the previous tokens
13265+
for(;; match_length++) {
13266+
// if we have reached the start of our last tokens, break
13267+
if(i < match_length) break;
13268+
13269+
// compare token starts at our prev index, going backwards by match length
13270+
auto compare_token = last_tokens[i - match_length];
13271+
13272+
// head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
13273+
auto head_token = last_tokens[last_tokens_size - 1 - match_length];
13274+
13275+
// if compare token is part of the sequence breakers, break out of the match
13276+
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
13277+
break;
13278+
13279+
// break out of the match if any tokens don't match
13280+
if(compare_token != head_token)
13281+
break;
13282+
}
13283+
13284+
// Check if the next token exists in the map
13285+
auto it = match_lengths.find(next_token);
13286+
13287+
if (it == match_lengths.end()) {
13288+
// Key does not exist, insert the new value
13289+
match_lengths[next_token] = match_length;
13290+
} else {
13291+
// Key exists, update it with the max of the new value or the existing value
13292+
it->second = std::max(it->second, match_length);
13293+
}
13294+
}
13295+
13296+
// apply penalties
13297+
for (const auto& pair : match_lengths) {
13298+
auto next_token = pair.first;
13299+
auto match_length = pair.second;
13300+
13301+
// if the match length is greater than our allowed length in config, we apply penalities
13302+
if(match_length > dry_allowed_length) {
13303+
13304+
// find our next token in the candidates->data
13305+
size_t i = 0;
13306+
for (; i < candidates->size; ++i) {
13307+
if (candidates->data[i].id == next_token) {
13308+
// calculate the penalty
13309+
float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
13310+
13311+
// apply the dry penalty
13312+
candidates->data[i].logit -= penalty;
13313+
break;
13314+
}
13315+
}
13316+
}
13317+
}
13318+
}
13319+
1323613320
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
1323713321
if (z >= 1.0f || candidates->size <= 2) {
1323813322
return;

0 commit comments

Comments
 (0)