Skip to content

Commit 75c37ed

Browse files
committed
fixed bug in dry sampler
1 parent 99b7760 commit 75c37ed

File tree

2 files changed

+66
-40
lines changed

2 files changed

+66
-40
lines changed

llama.cpp

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12832,60 +12832,86 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
1283212832
}
1283312833
}
1283412834

12835-
void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_token_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
12836-
// loop through each candidate
12837-
for (size_t i = 0; i < candidates->size; ++i) {
12835+
void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
12836+
// sanity check
12837+
GGML_ASSERT(last_tokens_size > 0);
12838+
12839+
// get the last token
12840+
auto last_token = last_tokens[last_tokens_size - 1];
12841+
12842+
// if last token is part of the sequence breakers, skip whole sampler
12843+
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) {
12844+
return;
12845+
}
1283812846

12839-
// if our candidate itself is part of the sequence breakers, we don't apply the dry penalty
12840-
if (std::find(seq_breakers, seq_breakers + seq_breakers_size, candidates->data[i].id) != seq_breakers + seq_breakers_size) {
12847+
// create an unordered map of "next tokens" <-> max match length
12848+
std::unordered_map<llama_token, size_t> match_lengths;
12849+
12850+
// loop through each previous token (exclude the last token)
12851+
for (size_t i = 0; i < last_tokens_size - 1; ++i) {
12852+
// skip the compare token if it's not the same as the last token
12853+
if(last_tokens[i] != last_token) {
1284112854
continue;
1284212855
}
1284312856

12844-
int max_match_length = 0;
12857+
// get the next token (i + 1 is always less than last_tokens_size)
12858+
auto next_token = last_tokens[i + 1];
1284512859

12846-
// loop through each previous token
12847-
for (size_t j = 0; j < last_token_size; ++j) {
12848-
// if the current candidate is the same as the previous token
12849-
if (candidates->data[i].id == last_tokens[j]) {
12850-
// greedily match sequence backwards starting from the current position with the end of prev
12851-
int match_length = 1;
12860+
// try to extend the match backwards (match length starts at 1 because the last token is already matched)
12861+
size_t match_length = 1;
1285212862

12853-
// loop through the previous tokens
12854-
for(;; match_length++) {
12855-
// if we have reached the start of our stored prev, break
12856-
if(j - match_length > 0) break;
12863+
// loop through the previous tokens
12864+
for(;; match_length++) {
12865+
// if we have reached the start of our last tokens, break
12866+
if(i < match_length) break;
1285712867

12858-
// this shouldn't happen because (j - match_length) should always be smaller than (size - match_length)
12859-
// but let's check here to avoid the unexpected
12860-
if(last_token_size - match_length < 0) break;
12868+
// compare token starts at our prev index, going backwards by match length
12869+
auto compare_token = last_tokens[i - match_length];
1286112870

12862-
// compare token starts at our prev index, going backwards by match length
12863-
auto compare_token = last_tokens[j - match_length];
12871+
// head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
12872+
auto head_token = last_tokens[last_tokens_size - 1 - match_length];
1286412873

12865-
// head token starts at the end of prev, going backwards by match length
12866-
auto head_token = last_tokens[last_token_size - match_length];
12874+
// if compare token is part of the sequence breakers, break out of the match
12875+
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
12876+
break;
1286712877

12868-
// if compare token is part of the sequence breakers, break out of the match
12869-
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
12870-
break;
12878+
// break out of the match if any tokens don't match
12879+
if(compare_token != head_token)
12880+
break;
12881+
}
1287112882

12872-
// break out of the match if any tokens don't match
12873-
if(compare_token != head_token)
12874-
break;
12875-
}
12883+
// Check if the next token exists in the map
12884+
auto it = match_lengths.find(next_token);
1287612885

12877-
// update our max match length
12878-
max_match_length = std::max(max_match_length, match_length);
12879-
}
12886+
if (it == match_lengths.end()) {
12887+
// Key does not exist, insert the new value
12888+
match_lengths[next_token] = match_length;
12889+
} else {
12890+
// Key exists, update it with the max of the new value or the existing value
12891+
it->second = std::max(it->second, match_length);
1288012892
}
12893+
}
1288112894

12882-
// apply penalties
12883-
if(max_match_length > dry_allowed_length) {
12884-
// calculate the penalty
12885-
float penalty = dry_multiplier * pow(dry_base, max_match_length - dry_allowed_length);
12895+
// apply penalties
12896+
for (const auto& pair : match_lengths) {
12897+
auto next_token = pair.first;
12898+
auto match_length = pair.second;
1288612899

12887-
// apply the dry penalty
12888-
candidates->data[i].logit -= penalty;
12900+
// if the match length is greater than our allowed length in config, we apply penalties
12901+
if(match_length > dry_allowed_length) {
12902+
12903+
// find our next token in the candidates->data
12904+
size_t i = 0;
12905+
for (; i < candidates->size; ++i) {
12906+
if (candidates->data[i].id == next_token) {
12907+
// calculate the penalty
12908+
float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
12909+
12910+
// apply the dry penalty
12911+
candidates->data[i].logit -= penalty;
12912+
break;
12913+
}
12914+
}
1288912915
}
1289012916
}
1289112917
}

llama.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -923,7 +923,7 @@ extern "C" {
923923
struct llama_context * ctx,
924924
llama_token_data_array * candidates,
925925
const llama_token * last_tokens,
926-
int last_token_size,
926+
int last_tokens_size,
927927
float dry_base,
928928
float dry_multiplier,
929929
int dry_allowed_length,

0 commit comments

Comments
 (0)