Use bucket sort for token logits #5101
Conversation
I made a stand-alone benchmark (attached). It reads logits from a file (previously generated with …). For instance: …

With the benchmark code you posted I also get worse performance than with the command I used to test my implementation, e.g. `main` …

`test_sorting`

Buckets below …
So, basically, an empty prompt. When the prompt is empty, the probabilities (or logits) are indeed distributed broadly. But as the context grows, we get more and more into the situation where only a few tokens have a non-negligible probability, so most tokens end up in the first or first few buckets. No, I haven't done anything special to the logits other than reducing their precision to 16 bits; these are the kind of logits you get from a context of 512. Oh, and I don't know how you are getting up to ~2000 t/s with this command as per your table.
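As a purely synthetic illustration of that skew (made-up logits and a hypothetical bucket layout over [-10, 10] as described at the bottom of this page, not real model output or the PR's code):

```cpp
#include <algorithm>
#include <cstdio>
#include <random>
#include <vector>

// Synthetic illustration only: most logits cluster in a narrow low band and a
// handful are high, so nearly all tokens land in the same few buckets.
// Numbers and bucket layout are assumptions, not taken from the PR.
int main() {
    constexpr int   n_vocab = 32000, nbuckets = 100;
    constexpr float lo = -10.0f, hi = 10.0f;
    constexpr float scale = nbuckets / (hi - lo);

    std::mt19937 rng(1234);
    std::normal_distribution<float> bulk(-7.0f, 0.5f); // the negligible-probability mass
    std::vector<float> logits(n_vocab);
    for (auto & x : logits) x = bulk(rng);
    for (int i = 0; i < 10; ++i) logits[i] = 5.0f + 0.3f*i; // a few plausible tokens

    std::vector<int> count(nbuckets, 0);
    for (float x : logits) {
        int ib = std::min(std::max((int)((x - lo)*scale), 0), nbuckets - 1);
        count[ib]++;
    }
    for (int ib = 0; ib < nbuckets; ++ib) {
        if (count[ib] > 0) std::printf("bucket %3d: %5d tokens\n", ib, count[ib]);
    }
    return 0;
}
```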
For the table I ran: …

I ran: …
The maximum number of tokens per bucket is still only 1348, and llama.cpp reported a sampling speed of 639 t/s, which is essentially the same as with 256 tokens.
In …. This causes the majority of tokens to receive the exact same logit. For KL divergence this does not matter. For bucket sort it does, because it is only faster if you have a way of distributing values to the buckets somewhat evenly. With a somewhat even distribution you can then apply standard sorting algorithms to each bucket, and because comparison-based sorting algorithms always scale with n·log n, sorting many small buckets is cheaper overall than sorting one large array.
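To put a rough number on that scaling argument (my own back-of-the-envelope estimate, assuming a 32000-token vocabulary split evenly over 100 buckets and counting comparisons as n·log₂ n):

$$
32000 \cdot \log_2 32000 \approx 4.8 \times 10^{5}
\qquad \text{vs.} \qquad
100 \cdot 320 \cdot \log_2 320 \approx 2.7 \times 10^{5}
$$

And in the top-k case only the few highest buckets need to be sorted at all, which is where most of the saving would come from.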
In which buckets the tokens end up is a matter of context. You can see in the example you posted above that many logits ended up in bucket 16. As nothing was done to the logits above bucket 16, this means that without the processing done in ….

But if, for the sake of argument, we assume for a moment that the test case I made is not realistic, the PR still lowers performance for more normal usage with ….
Where the crossover to the STL having better performance than your bucket sort lies is very likely dependent on many details (e.g., compiler, compiler version, CPU, context). So, I think, you definitely need a check there for ….
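For what it's worth, one way such a guard could look, purely as an illustration (the condition, the threshold value, and all names here are my assumptions, not something specified in this thread):

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical guard: use the STL for small candidate sets (or k covering everything)
// and only take the bucket-sort path above some empirically tuned threshold. The
// constant below is made up; the real crossover would have to be measured per
// compiler/CPU.
constexpr size_t BUCKET_SORT_MIN_CANDIDATES = 1024;

void sort_top_k(std::vector<std::pair<float, int>> & cand, size_t k) {
    auto cmp = [](const std::pair<float, int> & a, const std::pair<float, int> & b) {
        return a.first > b.first;
    };
    if (cand.size() < BUCKET_SORT_MIN_CANDIDATES || k >= cand.size()) {
        std::partial_sort(cand.begin(), cand.begin() + std::min(k, cand.size()), cand.end(), cmp);
        return;
    }
    // ... otherwise take the bucket-sort path (a sketch appears at the end of this page).
}
```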
Note that the scaling applied in ….
A valid point.

Obsoleted by #5109 (merged).
PR #5085 attempts to speed up token sampling for cases with large --top-k by using `std::nth_element` and then `std::sort`. This PR instead exploits the limited range of token logits to implement bucket sort. There are 100 equal-sized buckets in the logit range [-10, 10]. Tokens are first distributed to these buckets. Then only those buckets relevant for --top-k are actually sorted. This is also faster for --top-k 32000 because the runtime for comparison-based sorting algorithms increases more than linearly with the number of elements. I am assuming that the top tokens beyond --top-k can be arbitrarily changed without affecting program correctness.

For LLaMA 2 7b q8_0 I get on my system with an RTX 3090 and a Ryzen 3700X: …
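For readers skimming this page, here is a minimal stand-alone sketch of the approach described above (100 buckets over [-10, 10], distribute first, then sort only the buckets needed for top-k). Struct names, bucket orientation, and all details are my own illustration, not the PR's actual code:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative sketch, not the PR's actual implementation: 100 equal-sized buckets
// over the logit range [-10, 10]; tokens are binned first, then only the highest
// buckets needed to produce k results are sorted.
struct TokenLogit {
    int32_t id;
    float   logit;
};

std::vector<TokenLogit> top_k_bucket_sort(const std::vector<TokenLogit> & tokens, size_t k) {
    constexpr int   nbuckets = 100;
    constexpr float lo = -10.0f, hi = 10.0f;
    constexpr float scale = nbuckets / (hi - lo);

    // 1. Distribute tokens into buckets; clamp out-of-range logits to the edge buckets.
    std::vector<std::vector<TokenLogit>> buckets(nbuckets);
    for (const TokenLogit & t : tokens) {
        int ib = (int)((t.logit - lo)*scale);
        ib = std::min(std::max(ib, 0), nbuckets - 1);
        buckets[ib].push_back(t);
    }

    // 2. Walk buckets from the highest logits downwards; sort only the buckets that
    //    are actually needed to fill the top-k result.
    std::vector<TokenLogit> result;
    result.reserve(k);
    for (int ib = nbuckets - 1; ib >= 0 && result.size() < k; --ib) {
        std::vector<TokenLogit> & b = buckets[ib];
        std::sort(b.begin(), b.end(),
                  [](const TokenLogit & a, const TokenLogit & c) { return a.logit > c.logit; });
        for (const TokenLogit & t : b) {
            if (result.size() == k) break;
            result.push_back(t);
        }
    }
    return result;
}
```

The point of walking buckets from the top is that for typical --top-k values only a handful of buckets ever get sorted, while the distribution pass itself is linear in the vocabulary size.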