@@ -13233,6 +13233,90 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
13233
13233
}
13234
13234
}
13235
13235
13236
+ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
13237
+ // sanity check
13238
+ GGML_ASSERT(last_tokens_size > 0);
13239
+
13240
+ // get the last token
13241
+ auto last_token = last_tokens[last_tokens_size - 1];
13242
+
13243
+ // if last token is part of the sequence breakers, skip whole sampler
13244
+ if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) {
13245
+ return;
13246
+ }
13247
+
13248
+ // create an unordered map of "next tokens" <-> max match length
13249
+ std::unordered_map<llama_token, size_t> match_lengths;
13250
+
13251
+ // loop through each previous token (exclude the last token)
13252
+ for (size_t i = 0; i < last_tokens_size - 1; ++i) {
13253
+ // skip if the compare token if it's not the same as the last token
13254
+ if(last_tokens[i] != last_token) {
13255
+ continue;
13256
+ }
13257
+
13258
+ // get the next token (i + 1 is always less than last_tokens_size)
13259
+ auto next_token = last_tokens[i + 1];
13260
+
13261
+ // try to extend the match backwards (match length starts a 1 because last token is already matched)
13262
+ size_t match_length = 1;
13263
+
13264
+ // loop through the previous tokens
13265
+ for(;; match_length++) {
13266
+ // if we have reached the start of our last tokens, break
13267
+ if(i < match_length) break;
13268
+
13269
+ // compare token starts at our prev index, going backwards by match length
13270
+ auto compare_token = last_tokens[i - match_length];
13271
+
13272
+ // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
13273
+ auto head_token = last_tokens[last_tokens_size - 1 - match_length];
13274
+
13275
+ // if compare token is part of the sequence breakers, break out of the match
13276
+ if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
13277
+ break;
13278
+
13279
+ // break out of the match if any tokens don't match
13280
+ if(compare_token != head_token)
13281
+ break;
13282
+ }
13283
+
13284
+ // Check if the next token exists in the map
13285
+ auto it = match_lengths.find(next_token);
13286
+
13287
+ if (it == match_lengths.end()) {
13288
+ // Key does not exist, insert the new value
13289
+ match_lengths[next_token] = match_length;
13290
+ } else {
13291
+ // Key exists, update it with the max of the new value or the existing value
13292
+ it->second = std::max(it->second, match_length);
13293
+ }
13294
+ }
13295
+
13296
+ // apply penalties
13297
+ for (const auto& pair : match_lengths) {
13298
+ auto next_token = pair.first;
13299
+ auto match_length = pair.second;
13300
+
13301
+ // if the match length is greater than our allowed length in config, we apply penalities
13302
+ if(match_length > dry_allowed_length) {
13303
+
13304
+ // find our next token in the candidates->data
13305
+ size_t i = 0;
13306
+ for (; i < candidates->size; ++i) {
13307
+ if (candidates->data[i].id == next_token) {
13308
+ // calculate the penalty
13309
+ float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
13310
+
13311
+ // apply the dry penalty
13312
+ candidates->data[i].logit -= penalty;
13313
+ break;
13314
+ }
13315
+ }
13316
+ }
13317
+ }
13318
+ }
13319
+
13236
13320
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
13237
13321
if (z >= 1.0f || candidates->size <= 2) {
13238
13322
return;
0 commit comments