diff --git a/llama.cpp b/llama.cpp index 582e82260ea85..494b7a4e33a1d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8001,11 +8001,32 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }; - if (k == (int) candidates->size) { - std::sort(candidates->data, candidates->data + candidates->size, comp); - } else { - std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); + constexpr int nbuckets = 100; + constexpr float bucket_low = -10.0f; + constexpr float bucket_high = 10.0f; + + std::vector buckets[nbuckets]; + + for (size_t i = 0; i < candidates->size; ++i) { + const float val = candidates->data[i].logit; + int ib = nbuckets * (val - bucket_low) / (bucket_high - bucket_low); + ib = std::max(ib, 0); + ib = std::min(ib, nbuckets-1); + buckets[ib].push_back(candidates->data[i]); + } + + int nsorted = 0; + for (int ib = nbuckets-1; ib >= 0; --ib) { + std::sort(buckets[ib].begin(), buckets[ib].end(), comp); + memcpy(candidates->data + nsorted, buckets[ib].data(), buckets[ib].size()*sizeof(llama_token_data)); + + nsorted += buckets[ib].size(); + + if (nsorted >= k) { + break; + } } + candidates->sorted = true; } candidates->size = k;