@@ -12832,60 +12832,86 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
 }
 }
 
-void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_token_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
-    // loop through each candidate
-    for (size_t i = 0; i < candidates->size; ++i) {
+void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
+    // sanity check
+    GGML_ASSERT(last_tokens_size > 0);
+
+    // get the last token
+    auto last_token = last_tokens[last_tokens_size - 1];
+
+    // if the last token is part of the sequence breakers, skip the whole sampler
+    if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) {
+        return;
+    }
 
-        // if our candidate itself is part of the sequence breakers, we don't apply the dry penalty
-        if (std::find(seq_breakers, seq_breakers + seq_breakers_size, candidates->data[i].id) != seq_breakers + seq_breakers_size) {
+    // create an unordered map of "next tokens" <-> max match length
+    std::unordered_map<llama_token, size_t> match_lengths;
+
+    // loop through each previous token (exclude the last token)
+    for (size_t i = 0; i < last_tokens_size - 1; ++i) {
+        // skip if the compare token is not the same as the last token
+        if(last_tokens[i] != last_token) {
             continue;
         }
 
-        int max_match_length = 0;
+        // get the next token (i + 1 is always less than last_tokens_size)
+        auto next_token = last_tokens[i + 1];
 
-        // loop through each previous token
-        for (size_t j = 0; j < last_token_size; ++j) {
-            // if the current candidate is the same as the previous token
-            if (candidates->data[i].id == last_tokens[j]) {
-                // greedily match sequence backwards starting from the current position with the end of prev
-                int match_length = 1;
+        // try to extend the match backwards (match length starts at 1 because the last token is already matched)
+        size_t match_length = 1;
 
-                // loop through the previous tokens
-                for(;; match_length++) {
-                    // if we have reached the start of our stored prev, break
-                    if(j - match_length > 0) break;
+        // loop through the previous tokens
+        for(;; match_length++) {
+            // if we have reached the start of our last tokens, break
+            if(i < match_length) break;
 
-                    // this shouldn't happen because (j - match_length) should always be smaller than (size - match_length)
-                    // but let's check here to avoid the unexpected
-                    if(last_token_size - match_length < 0) break;
+            // compare token starts at our prev index, going backwards by match length
+            auto compare_token = last_tokens[i - match_length];
 
-                    // compare token starts at our prev index, going backwards by match length
-                    auto compare_token = last_tokens[j - match_length];
+            // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
+            auto head_token = last_tokens[last_tokens_size - 1 - match_length];
 
-                    // head token starts at the end of prev, going backwards by match length
-                    auto head_token = last_tokens[last_token_size - match_length];
+            // if compare token is part of the sequence breakers, break out of the match
+            if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
+                break;
 
-                    // if compare token is part of the sequence breakers, break out of the match
-                    if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
-                        break;
+            // break out of the match if any tokens don't match
+            if(compare_token != head_token)
+                break;
+        }
 
-                    // break out of the match if any tokens don't match
-                    if(compare_token != head_token)
-                        break;
-                }
+        // check if the next token already exists in the map
+        auto it = match_lengths.find(next_token);
 
-                // update our max match length
-                max_match_length = std::max(max_match_length, match_length);
-            }
+        if (it == match_lengths.end()) {
+            // key does not exist, insert the new value
+            match_lengths[next_token] = match_length;
+        } else {
+            // key exists, update it with the max of the new value or the existing value
+            it->second = std::max(it->second, match_length);
         }
+    }
 
-        // apply penalties
-        if(max_match_length > dry_allowed_length) {
-            // calculate the penalty
-            float penalty = dry_multiplier * pow(dry_base, max_match_length - dry_allowed_length);
+    // apply penalties
+    for (const auto& pair : match_lengths) {
+        auto next_token = pair.first;
+        auto match_length = pair.second;
 
-            // apply the dry penalty
-            candidates->data[i].logit -= penalty;
+        // if the match length is greater than our allowed length in config, we apply penalties
+        if(match_length > dry_allowed_length) {
+
+            // find our next token in candidates->data
+            size_t i = 0;
+            for (; i < candidates->size; ++i) {
+                if (candidates->data[i].id == next_token) {
+                    // calculate the penalty
+                    float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
+
+                    // apply the dry penalty
+                    candidates->data[i].logit -= penalty;
+                    break;
+                }
+            }
         }
     }
 }
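
To make the control flow above easier to trace, here is a standalone sketch (not part of the patch) that runs the same backward-matching and penalty logic on a toy token sequence. The token IDs and the dry_base / dry_multiplier / dry_allowed_length values are made up for illustration, and the sequence-breaker check is omitted for brevity. In this toy context the suffix "7 8" has already occurred once, so with dry_allowed_length = 1 the token that previously followed it (9) gets a match length of 2 and is penalized by dry_multiplier * dry_base^(2 - 1) = 3.5.

// Standalone illustration of the DRY matching + penalty math above.
// Toy token IDs and hypothetical parameter values; seq_breakers omitted.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <unordered_map>
#include <vector>

int main() {
    // toy context: "... 7 8 9 7 8" -- the suffix "7 8" occurred before,
    // so token 9 (which followed "7 8" last time) should be penalized
    std::vector<int> last_tokens = {1, 7, 8, 9, 7, 8};
    const float  dry_base           = 1.75f;
    const float  dry_multiplier     = 2.0f;
    const size_t dry_allowed_length = 1;

    const int last_token = last_tokens.back();
    std::unordered_map<int, size_t> match_lengths;

    for (size_t i = 0; i + 1 < last_tokens.size(); ++i) {
        if (last_tokens[i] != last_token) continue;
        const int next_token = last_tokens[i + 1];

        // extend the match backwards, exactly like the loop in the patch
        size_t match_length = 1;
        for (;; match_length++) {
            if (i < match_length) break;
            const int compare_token = last_tokens[i - match_length];
            const int head_token    = last_tokens[last_tokens.size() - 1 - match_length];
            if (compare_token != head_token) break;
        }

        // keep the maximum match length seen for this next token
        auto it = match_lengths.find(next_token);
        if (it == match_lengths.end()) {
            match_lengths[next_token] = match_length;
        } else {
            it->second = std::max(it->second, match_length);
        }
    }

    // report the penalty each over-length repeat would receive
    for (const auto& pair : match_lengths) {
        if (pair.second > dry_allowed_length) {
            const float penalty = dry_multiplier * powf(dry_base, (float)(pair.second - dry_allowed_length));
            printf("token %d: match_length=%zu, penalty=%.3f\n", pair.first, pair.second, penalty);
        }
    }
    return 0;
}

Because the penalty grows exponentially in match_length, a would-be repeat is suppressed harder the longer the repeated run gets, which is the intended behavior of the DRY repetition penalty.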