
Use bucket sort for token logits #5101


Closed


JohannesGaessler (Collaborator)

PR #5085 attempts to speed up token sampling for cases with large --top-k by using std::nth_element and then std::sort. This PR instead exploits the limited range of token logits to implement a bucket sort. There are 100 equal-sized buckets covering the logit range [-10, 10]. Tokens are first distributed to these buckets; then only those buckets relevant for --top-k are actually sorted. This is faster even for --top-k 32000 because the runtime of comparison-based sorting algorithms grows more than linearly with the number of elements. I am assuming that the order of the tokens beyond --top-k can be changed arbitrarily without affecting program correctness.
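
For illustration, here is a minimal sketch of the idea (not the PR's actual code; the token struct, the function name, and the clamping of out-of-range logits are assumptions):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative token record; the real llama.cpp type differs.
struct token_data {
    int32_t id;
    float   logit;
};

// Returns all tokens, with the top_k highest logits sorted in descending order;
// the order of the remaining tokens is left unspecified.
static std::vector<token_data> top_k_bucket_sort(const std::vector<token_data> & tokens, size_t top_k) {
    constexpr int   n_buckets = 100;          // 100 equal-sized buckets ...
    constexpr float lo = -10.0f, hi = 10.0f;  // ... covering the logit range [-10, 10]
    constexpr float scale = n_buckets / (hi - lo);

    // Distribute tokens to buckets by logit value.
    std::vector<std::vector<token_data>> buckets(n_buckets);
    for (const token_data & t : tokens) {
        int ib = (int) ((t.logit - lo) * scale);
        ib = std::max(0, std::min(n_buckets - 1, ib)); // clamp out-of-range logits (assumption)
        buckets[ib].push_back(t);
    }

    std::vector<token_data> result;
    result.reserve(tokens.size());

    // Walk buckets from the highest logits downwards and sort only as many
    // buckets as are needed to cover the first top_k elements.
    for (int ib = n_buckets - 1; ib >= 0; --ib) {
        if (result.size() < top_k) {
            std::sort(buckets[ib].begin(), buckets[ib].end(),
                      [](const token_data & a, const token_data & b) { return a.logit > b.logit; });
        }
        result.insert(result.end(), buckets[ib].begin(), buckets[ib].end());
    }
    return result;
}
```

The key property is that the buckets partition the value range, so the concatenation above is already globally ordered at bucket granularity; per-bucket sorting only fixes the order within each bucket.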

For LLaMA 2 7b q8_0 I get on my system with an RTX 3090 and a Ryzen 3700X:

| --top-k | t/s master | t/s nth_element | Speedup nth_element | t/s bucket sort | Speedup bucket sort |
|--------:|-----------:|----------------:|--------------------:|----------------:|--------------------:|
| 1000 | 1991 | 1973 | 0.99 | 2213 | 1.11 |
| 2000 | 1396 | 1375 | 0.98 | 2065 | 1.48 |
| 4000 | 893 | 1478 | 1.66 | 1821 | 2.04 |
| 8000 | 574 | 1123 | 1.96 | 1416 | 2.47 |
| 16000 | 389 | 778 | 2.00 | 997 | 2.56 |
| 31780 | 309 | 492 | 1.59 | 639 | 2.07 |
| 32000 | 492 | 485 | 0.99 | 637 | 1.29 |

(attached plot: bucket_sort)

@ikawrakow (Contributor)

I made a stand-alone benchmark (attached). It reads logits from a file (previously generated with perplexity --kl-divergence-base) and tests std::sort/std::partial_sort against this bucket sort. On my system the standard library is faster than this bucket sort by 2X..20X, depending on top_k. That is not surprising, given the few hundred allocations that happen in the bucket sort.

For instance

./perplexity -m models/Mistral-7B/ggml-model-f16.gguf -f tests/wiki.test.raw --kl-divergence-base m7.logits

g++ -O3 -o test_sorting test_sorting.cpp

./test_sorting m7.logits 100 50
std::sort:
  <time> = 21.241 us
  top_1  = 0.700492
  top_k  = 0.000402343
Bucket:
  <time> = 437.53 us
  top_1  = 0.700492
  top_k  = 0.000402343

./test_sorting some_logits 100 500
std::sort:
  <time> = 75.9663 us
  top_1  = 0.700492
  top_k  = 1.52254e-05
Bucket:
  <time> = 504.853 us
  top_1  = 0.700492
  top_k  = 1.52254e-05

./test_sorting some_logits 100 5000
std::sort:
  <time> = 236.347 us
  top_1  = 0.700492
  top_k  = 1.92905e-07
Bucket:
  <time> = 664.981 us
  top_1  = 0.700492
  top_k  = 1.92905e-07

test_sorting.cpp.gz

@JohannesGaessler (Collaborator, Author) commented Jan 24, 2024

With the benchmark code you posted I also get worse performance than with the command I used to test my implementation, e.g. ./main --model models/nvme/${model_name}-${quantization}.gguf -ngl 99 --ctx-size 4096 --ignore-eos --n-predict 256 --seed 1337 --top-k 32000. But I think the issue is that the benchmark is simply not representative in terms of logit value distribution, since you apply some transformations that are negligible for the calculation of the KL divergence but not for sorting. I added a simple debug print to inspect how the values are distributed to buckets with --top-k 32000: with main you end up with buckets holding up to ~3% of the values, while with the benchmark a single bucket holds ~80% of the values.
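
A hedged guess at what such a debug print might look like (the bucket container and loop direction are inferred from the output below; this is not the actual patch):

```cpp
#include <cstdio>
#include <vector>

// Hypothetical reconstruction: print the bucket histogram, highest bucket first.
template <typename T>
static void print_bucket_histogram(const std::vector<std::vector<T>> & buckets) {
    for (int ib = (int) buckets.size() - 1; ib >= 0; --ib) {
        printf("ib=%d buckets[ib].size()=%zu\n", ib, buckets[ib].size());
    }
}
```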

main
ib=99 buckets[ib].size()=46
ib=98 buckets[ib].size()=4
ib=97 buckets[ib].size()=5
ib=96 buckets[ib].size()=4
ib=95 buckets[ib].size()=8
ib=94 buckets[ib].size()=6
ib=93 buckets[ib].size()=9
ib=92 buckets[ib].size()=10
ib=91 buckets[ib].size()=14
ib=90 buckets[ib].size()=17
ib=89 buckets[ib].size()=17
ib=88 buckets[ib].size()=19
ib=87 buckets[ib].size()=22
ib=86 buckets[ib].size()=31
ib=85 buckets[ib].size()=20
ib=84 buckets[ib].size()=27
ib=83 buckets[ib].size()=35
ib=82 buckets[ib].size()=36
ib=81 buckets[ib].size()=32
ib=80 buckets[ib].size()=44
ib=79 buckets[ib].size()=48
ib=78 buckets[ib].size()=60
ib=77 buckets[ib].size()=68
ib=76 buckets[ib].size()=56
ib=75 buckets[ib].size()=93
ib=74 buckets[ib].size()=91
ib=73 buckets[ib].size()=95
ib=72 buckets[ib].size()=110
ib=71 buckets[ib].size()=130
ib=70 buckets[ib].size()=158
ib=69 buckets[ib].size()=157
ib=68 buckets[ib].size()=162
ib=67 buckets[ib].size()=207
ib=66 buckets[ib].size()=194
ib=65 buckets[ib].size()=269
ib=64 buckets[ib].size()=302
ib=63 buckets[ib].size()=329
ib=62 buckets[ib].size()=371
ib=61 buckets[ib].size()=363
ib=60 buckets[ib].size()=445
ib=59 buckets[ib].size()=452
ib=58 buckets[ib].size()=518
ib=57 buckets[ib].size()=556
ib=56 buckets[ib].size()=605
ib=55 buckets[ib].size()=672
ib=54 buckets[ib].size()=738
ib=53 buckets[ib].size()=838
ib=52 buckets[ib].size()=834
ib=51 buckets[ib].size()=897
ib=50 buckets[ib].size()=994
ib=49 buckets[ib].size()=962
ib=48 buckets[ib].size()=996
ib=47 buckets[ib].size()=1075
ib=46 buckets[ib].size()=1068
ib=45 buckets[ib].size()=1110
ib=44 buckets[ib].size()=1120
ib=43 buckets[ib].size()=1091
ib=42 buckets[ib].size()=1136
ib=41 buckets[ib].size()=1126
ib=40 buckets[ib].size()=1126
ib=39 buckets[ib].size()=950
ib=38 buckets[ib].size()=1046
ib=37 buckets[ib].size()=921
ib=36 buckets[ib].size()=825
ib=35 buckets[ib].size()=824
ib=34 buckets[ib].size()=753
ib=33 buckets[ib].size()=621
ib=32 buckets[ib].size()=679
ib=31 buckets[ib].size()=552
ib=30 buckets[ib].size()=456
ib=29 buckets[ib].size()=414
ib=28 buckets[ib].size()=324
ib=27 buckets[ib].size()=269
ib=26 buckets[ib].size()=258
ib=25 buckets[ib].size()=219
ib=24 buckets[ib].size()=184
ib=23 buckets[ib].size()=131
ib=22 buckets[ib].size()=94
ib=21 buckets[ib].size()=100
ib=20 buckets[ib].size()=81
ib=19 buckets[ib].size()=64
ib=18 buckets[ib].size()=49
ib=17 buckets[ib].size()=38
ib=16 buckets[ib].size()=30
ib=15 buckets[ib].size()=24
ib=14 buckets[ib].size()=18
ib=13 buckets[ib].size()=10
ib=12 buckets[ib].size()=12
ib=11 buckets[ib].size()=6
ib=10 buckets[ib].size()=5
ib=9 buckets[ib].size()=3
ib=8 buckets[ib].size()=3
ib=7 buckets[ib].size()=2
ib=6 buckets[ib].size()=3
ib=5 buckets[ib].size()=1
ib=4 buckets[ib].size()=0
ib=3 buckets[ib].size()=1
ib=2 buckets[ib].size()=0
ib=1 buckets[ib].size()=0
ib=0 buckets[ib].size()=2
test_sorting
ib=99 buckets[ib].size()=0
ib=98 buckets[ib].size()=0
ib=97 buckets[ib].size()=0
ib=96 buckets[ib].size()=1
ib=95 buckets[ib].size()=0
ib=94 buckets[ib].size()=0
ib=93 buckets[ib].size()=0
ib=92 buckets[ib].size()=0
ib=91 buckets[ib].size()=1
ib=90 buckets[ib].size()=1
ib=89 buckets[ib].size()=0
ib=88 buckets[ib].size()=0
ib=87 buckets[ib].size()=0
ib=86 buckets[ib].size()=0
ib=85 buckets[ib].size()=0
ib=84 buckets[ib].size()=0
ib=83 buckets[ib].size()=1
ib=82 buckets[ib].size()=1
ib=81 buckets[ib].size()=1
ib=80 buckets[ib].size()=0
ib=79 buckets[ib].size()=0
ib=78 buckets[ib].size()=1
ib=77 buckets[ib].size()=2
ib=76 buckets[ib].size()=0
ib=75 buckets[ib].size()=1
ib=74 buckets[ib].size()=0
ib=73 buckets[ib].size()=2
ib=72 buckets[ib].size()=2
ib=71 buckets[ib].size()=1
ib=70 buckets[ib].size()=3
ib=69 buckets[ib].size()=2
ib=68 buckets[ib].size()=1
ib=67 buckets[ib].size()=2
ib=66 buckets[ib].size()=3
ib=65 buckets[ib].size()=4
ib=64 buckets[ib].size()=4
ib=63 buckets[ib].size()=6
ib=62 buckets[ib].size()=8
ib=61 buckets[ib].size()=6
ib=60 buckets[ib].size()=5
ib=59 buckets[ib].size()=6
ib=58 buckets[ib].size()=4
ib=57 buckets[ib].size()=3
ib=56 buckets[ib].size()=6
ib=55 buckets[ib].size()=11
ib=54 buckets[ib].size()=15
ib=53 buckets[ib].size()=14
ib=52 buckets[ib].size()=10
ib=51 buckets[ib].size()=20
ib=50 buckets[ib].size()=20
ib=49 buckets[ib].size()=29
ib=48 buckets[ib].size()=26
ib=47 buckets[ib].size()=25
ib=46 buckets[ib].size()=27
ib=45 buckets[ib].size()=33
ib=44 buckets[ib].size()=49
ib=43 buckets[ib].size()=48
ib=42 buckets[ib].size()=51
ib=41 buckets[ib].size()=52
ib=40 buckets[ib].size()=73
ib=39 buckets[ib].size()=68
ib=38 buckets[ib].size()=92
ib=37 buckets[ib].size()=100
ib=36 buckets[ib].size()=102
ib=35 buckets[ib].size()=106
ib=34 buckets[ib].size()=119
ib=33 buckets[ib].size()=147
ib=32 buckets[ib].size()=150
ib=31 buckets[ib].size()=191
ib=30 buckets[ib].size()=200
ib=29 buckets[ib].size()=201
ib=28 buckets[ib].size()=241
ib=27 buckets[ib].size()=285
ib=26 buckets[ib].size()=282
ib=25 buckets[ib].size()=336
ib=24 buckets[ib].size()=345
ib=23 buckets[ib].size()=378
ib=22 buckets[ib].size()=409
ib=21 buckets[ib].size()=446
ib=20 buckets[ib].size()=468
ib=19 buckets[ib].size()=512
ib=18 buckets[ib].size()=597
ib=17 buckets[ib].size()=557
ib=16 buckets[ib].size()=25087

Buckets below ib=16 are empty.

@ikawrakow (Contributor)

With the benchmark code you posted I also get worse performance than with the command I used to test my implementation, e.g. ./main --model models/nvme/${model_name}-${quantization}.gguf -ngl 99 --ctx-size 4096 --ignore-eos --n-predict 256 --seed 1337 --top-k 32000

So, basically, an empty prompt. When the prompt is empty, the probabilities (or logits) are indeed distributed broadly. But as the context grows, we get more and more into the situation where only a few tokens have a non-negligible probability, so most tokens end up in the first bucket or the first few buckets. No, I haven't done anything special to the logits other than reducing the precision to 16 bits. These are the kind of logits you get from a context of 512.

Oh, and I don't know how you are getting up to ~2000 t/s with this command as per your table.

@JohannesGaessler (Collaborator, Author)

Oh, and I don't know how you are getting up to ~2000 t/s with this command as per your table.

For the table I ran ./main --model models/nvme/${model_name}-${quantization}.gguf -ngl 99 --ctx-size 4096 --ignore-eos --n-predict 256 --seed 1337 --top-k <TOP_K_VALUE> and simply wrote down the sampling t/s reported by llama.cpp.

So, basically, empty prompt. When the prompt is empty, then the probabilities (or logits) are indeed distributed broadly. But when the context grows, then we get more and more into the situation where only a few tokens have a non-negligible probability, so most tokens end up in the first or first few buckets. [...] These are the kind of logits you get from a context of 512.

I ran ./main --model models/nvme/${model_name}-${quantization}.gguf -ngl 99 --ctx-size 4096 --ignore-eos --n-predict 4096 --seed 1337 --top-k 32000 and looked at how the logits are distributed for the last token:

ib=99 buckets[ib].size()=189
ib=98 buckets[ib].size()=30
ib=97 buckets[ib].size()=25
ib=96 buckets[ib].size()=37
ib=95 buckets[ib].size()=33
ib=94 buckets[ib].size()=40
ib=93 buckets[ib].size()=65
ib=92 buckets[ib].size()=67
ib=91 buckets[ib].size()=71
ib=90 buckets[ib].size()=87
ib=89 buckets[ib].size()=111
ib=88 buckets[ib].size()=106
ib=87 buckets[ib].size()=127
ib=86 buckets[ib].size()=143
ib=85 buckets[ib].size()=163
ib=84 buckets[ib].size()=187
ib=83 buckets[ib].size()=223
ib=82 buckets[ib].size()=256
ib=81 buckets[ib].size()=238
ib=80 buckets[ib].size()=260
ib=79 buckets[ib].size()=277
ib=78 buckets[ib].size()=314
ib=77 buckets[ib].size()=361
ib=76 buckets[ib].size()=354
ib=75 buckets[ib].size()=395
ib=74 buckets[ib].size()=405
ib=73 buckets[ib].size()=435
ib=72 buckets[ib].size()=463
ib=71 buckets[ib].size()=507
ib=70 buckets[ib].size()=572
ib=69 buckets[ib].size()=609
ib=68 buckets[ib].size()=659
ib=67 buckets[ib].size()=769
ib=66 buckets[ib].size()=794
ib=65 buckets[ib].size()=852
ib=64 buckets[ib].size()=990
ib=63 buckets[ib].size()=1047
ib=62 buckets[ib].size()=1117
ib=61 buckets[ib].size()=1198
ib=60 buckets[ib].size()=1263
ib=59 buckets[ib].size()=1224
ib=58 buckets[ib].size()=1276
ib=57 buckets[ib].size()=1348
ib=56 buckets[ib].size()=1292
ib=55 buckets[ib].size()=1312
ib=54 buckets[ib].size()=1211
ib=53 buckets[ib].size()=1081
ib=52 buckets[ib].size()=1059
ib=51 buckets[ib].size()=981
ib=50 buckets[ib].size()=889
ib=49 buckets[ib].size()=885
ib=48 buckets[ib].size()=669
ib=47 buckets[ib].size()=608
ib=46 buckets[ib].size()=480
ib=45 buckets[ib].size()=410
ib=44 buckets[ib].size()=330
ib=43 buckets[ib].size()=260
ib=42 buckets[ib].size()=205
ib=41 buckets[ib].size()=159
ib=40 buckets[ib].size()=117
ib=39 buckets[ib].size()=89
ib=38 buckets[ib].size()=71
ib=37 buckets[ib].size()=47
ib=36 buckets[ib].size()=43
ib=35 buckets[ib].size()=30
ib=34 buckets[ib].size()=23
ib=33 buckets[ib].size()=14
ib=32 buckets[ib].size()=8
ib=31 buckets[ib].size()=11
ib=30 buckets[ib].size()=9
ib=29 buckets[ib].size()=10
ib=28 buckets[ib].size()=4
ib=27 buckets[ib].size()=1
ib=26 buckets[ib].size()=2
ib=25 buckets[ib].size()=0
ib=24 buckets[ib].size()=0
ib=23 buckets[ib].size()=0
ib=22 buckets[ib].size()=1
ib=21 buckets[ib].size()=0
ib=20 buckets[ib].size()=0
ib=19 buckets[ib].size()=0
ib=18 buckets[ib].size()=1
ib=17 buckets[ib].size()=0
ib=16 buckets[ib].size()=0
ib=15 buckets[ib].size()=0
ib=14 buckets[ib].size()=0
ib=13 buckets[ib].size()=0
ib=12 buckets[ib].size()=0
ib=11 buckets[ib].size()=0
ib=10 buckets[ib].size()=0
ib=9 buckets[ib].size()=0
ib=8 buckets[ib].size()=0
ib=7 buckets[ib].size()=0
ib=6 buckets[ib].size()=0
ib=5 buckets[ib].size()=0
ib=4 buckets[ib].size()=0
ib=3 buckets[ib].size()=0
ib=2 buckets[ib].size()=0
ib=1 buckets[ib].size()=0
ib=0 buckets[ib].size()=1

The maximum number of tokens per bucket is still only 1348, and llama.cpp reported a sampling speed of 639 t/s, which is essentially the same as with 256 generated tokens.

No, I haven't done anything special to the logits other than to reduce precision to 16 bits.

In perplexity.cpp lines 129 and 144 you truncate the logits that you write to the file. You said here:

So, to reduce the size of the data being stored in the base run, I store them as uint16_t (the log-probabilities for wiki.test.run would be 20 GB, we have a size of 10 GB that way). The minimum logit can be very small, so I have decided to limit the probability range to e^(-16) ~ 1e-7. This slightly improves the precision of the 16-bit values being stored.

This causes the majority of tokens to receive the exact same logit. For KL divergence this does not matter. For bucket sort it does, because bucket sort is only faster if you have a way of distributing the values to the buckets somewhat evenly. With a somewhat even distribution you can apply standard sorting algorithms to each bucket, and because comparison-based sorting algorithms always scale with $O(n \log n)$ or worse, sorting the buckets is faster than sorting all values at once. But if you have a degenerate distribution where (almost) all values end up in the same bucket, then you pay the overhead of copying the values to buckets while sorting the buckets is not any faster than sorting all values at once.
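
To put rough numbers on that scaling argument (a back-of-the-envelope estimate, not a measurement): with $n$ values spread roughly evenly over $b$ non-empty buckets, sorting bucket by bucket costs on the order of

$$b \cdot \frac{n}{b} \log_2 \frac{n}{b} = n \log_2 \frac{n}{b}$$

comparisons instead of $n \log_2 n$. For $n = 32000$ and $b = 100$ this shrinks the log factor from roughly 15 to roughly 8, while a degenerate distribution with (almost) everything in one bucket keeps the full $n \log_2 n$ cost and adds the overhead of distributing the values.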

@ikawrakow (Contributor) commented Jan 24, 2024

In which buckets the tokens end up is a matter of context. You can see in the example you posted above that many logits ended up in bucket 16. As nothing was done to the logits above bucket 16, this means that without the processing done in perplexity when saving the logit data, the 25k+ tokens that are now all in bucket 16 would have been distributed across just the 16 buckets below, rather than across 60+ buckets as happens in your usage.

But even if, for the sake of argument, we assume for a moment that the test case I made is not realistic, the PR still lowers performance for more normal usage with top_k < 1000. On my system with an empty prompt and 256 generated tokens (i.e., using the exact same command as you did) I get:

| top_k | Master t/s | PR t/s | Speedup |
|------:|-----------:|-------:|--------:|
| 40 | 8648 | 3034 | 0.351 |
| 100 | 7758 | 3076 | 0.396 |
| 200 | 6796 | 3039 | 0.447 |
| 500 | 4847 | 2905 | 0.599 |
| 1000 | 3353 | 2872 | 0.856 |
| 2000 | 2133 | 2592 | 1.215 |

Where the crossover between the STL being faster and your bucket sort being faster lies very likely depends on many details (e.g., compiler, compiler version, CPU, context). So, I think, you definitely need a check there so that the bucket sort is only applied for k > some_constant.
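
A hedged sketch of what such a guard could look like (the threshold is made up and would need tuning per system; token_data and top_k_bucket_sort are the illustrative helpers from the sketch in the PR description above, not llama.cpp's actual sampling API):

```cpp
#include <algorithm>
#include <vector>

static void sort_top_k(std::vector<token_data> & tokens, size_t k) {
    constexpr size_t bucket_sort_min_k = 2000; // assumed crossover point, needs benchmarking

    k = std::min(k, tokens.size());
    if (k >= bucket_sort_min_k) {
        // large k: the bucket sort amortizes its distribution overhead
        tokens = top_k_bucket_sort(tokens, k);
    } else {
        // small k: the standard library partial sort is faster
        std::partial_sort(tokens.begin(), tokens.begin() + k, tokens.end(),
                          [](const token_data & a, const token_data & b) { return a.logit > b.logit; });
    }
}
```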

@JohannesGaessler (Collaborator, Author)

In which buckets the tokens end up is a matter of context. You see in the example you posted above that many logits ended up in bucket 16. As nothing was done to the logits above bucket 16, this means that without the processing done in perplexity when saving the logit data, the 25k+ tokens that are now all in bucket 16 would have been distributed in just 16 buckets rather than 60+ buckets as it happens in your usage.

Note that the scaling applied in perplexity.cpp is relative to the maximum logit value, while the buckets are defined based on the absolute logit values. It is technically true that the exact distribution depends on the context, but as far as I can tell the logits that come out of the LLaMA models are almost all centered around 0, with a few tokens that have logits around ~25 (which are the ones that actually make sense given the context). So the truncation in perplexity.cpp causes all tokens with logits below ~9 to end up in the same bucket. These are essentially all garbage tokens and can be ignored for practical purposes, but if you want to implement top_k sampling without approximations (or without first filtering by min_p) they need to be considered.

Where the crossover between the STL being faster and your bucket sort being faster lies very likely depends on many details (e.g., compiler, compiler version, CPU, context). So, I think, you definitely need a check there so that the bucket sort is only applied for k > some_constant.

A valid point.

@cebtenzzre (Collaborator)

Obsoleted by #5109 (merged)

@cebtenzzre closed this Jan 26, 2024