CUDA: Improve flash decoding kernel GPU occupancy for BS=1 case #12183
Conversation
Force-pushed from be39646 to 76881ac.
@JohannesGaessler @slaren @ggerganov Gentle reminder to review this once you have free cycles. This PR improves gen phase throughput significantly for large sequence lengths.
Here are some tests on a V100 - not sure how relevant this GPU is, but I don't have any other NVIDIA GPU atm. It's strange that on master, enabling FA results in overall worse performance for the prompt processing, and also sometimes for the text generation (3B Qwen). Not sure if this is expected since I haven't run CUDA tests in a long time.
Master
./bin/llama-bench -m Qwen2.5-Coder-3B-Q4_0.gguf -p 512,4096 -pg 16384,128 -fa 0,1
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes
model | size | params | backend | ngl | fa | test | t/s |
---|---|---|---|---|---|---|---|
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | pp512 | 5637.74 ± 9.82 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | pp4096 | 4729.83 ± 20.11 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | tg128 | 143.32 ± 0.39 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | pp16384+tg128 | 2058.38 ± 5.73 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | pp512 | 5323.62 ± 255.81 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | pp4096 | 4064.98 ± 13.52 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | tg128 | 134.82 ± 7.17 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | pp16384+tg128 | 1614.09 ± 58.57 |
build: 57b6abf (4836)
./bin/llama-batched-bench -m Qwen2.5-Coder-3B-Q4_0.gguf -c 32768 -b 2048 -ub 512 -npp 16384 -ntg 128,128 -npl 1 -ngl 99
main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 6, n_threads_batch = 6
PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
---|---|---|---|---|---|---|---|---|---|
16384 | 128 | 1 | 16512 | 6.390 | 2564.14 | 1.746 | 73.33 | 8.135 | 2029.68 |
16384 | 128 | 1 | 16512 | 6.379 | 2568.26 | 1.749 | 73.18 | 8.129 | 2031.36 |
./bin/llama-batched-bench -m Qwen2.5-Coder-3B-Q4_0.gguf -c 32768 -b 2048 -ub 512 -npp 16384 -ntg 128,128 -npl 1 -ngl 99 -fa
main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 6, n_threads_batch = 6
PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
---|---|---|---|---|---|---|---|---|---|
16384 | 128 | 1 | 16512 | 8.250 | 1985.95 | 2.076 | 61.65 | 10.326 | 1599.05 |
16384 | 128 | 1 | 16512 | 8.210 | 1995.55 | 2.080 | 61.54 | 10.290 | 1604.64 |
./bin/llama-bench -m Qwen2.5-Coder-7B-Instruct-Q4_0.gguf -p 512,4096 -pg 16384,128 -fa 0,1
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes
model | size | params | backend | ngl | fa | test | t/s |
---|---|---|---|---|---|---|---|
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | pp512 | 2978.12 ± 3.10 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | pp4096 | 2610.31 ± 1.46 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | tg128 | 117.44 ± 0.42 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | pp16384+tg128 | 1349.76 ± 0.37 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | pp512 | 2881.94 ± 3.95 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | pp4096 | 2248.26 ± 123.64 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | tg128 | 121.96 ± 0.45 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | pp16384+tg128 | 1163.20 ± 0.27 |
build: 57b6abf (4836)
PR
./bin/llama-bench -m Qwen2.5-Coder-3B-Q4_0.gguf -p 512,4096 -pg 16384,128 -fa 0,1
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes
model | size | params | backend | ngl | fa | test | t/s |
---|---|---|---|---|---|---|---|
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | pp512 | 5615.01 ± 16.51 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | pp4096 | 4726.22 ± 18.17 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | tg128 | 143.53 ± 0.33 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | pp16384+tg128 | 2064.51 ± 0.91 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | pp512 | 5448.27 ± 13.05 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | pp4096 | 4021.94 ± 1.05 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | tg128 | 148.28 ± 0.73 |
qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | pp16384+tg128 | 1747.92 ± 6.05 |
build: 741b729 (4825)
./bin/llama-batched-bench -m Qwen2.5-Coder-3B-Q4_0.gguf -c 32768 -b 2048 -ub 512 -npp 16384 -ntg 128,128 -npl 1 -ngl 99
main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 6, n_threads_batch = 6
PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
---|---|---|---|---|---|---|---|---|---|
16384 | 128 | 1 | 16512 | 6.393 | 2562.87 | 1.755 | 72.95 | 8.148 | 2026.62 |
16384 | 128 | 1 | 16512 | 6.379 | 2568.56 | 1.751 | 73.09 | 8.130 | 2031.00 |
./bin/llama-batched-bench -m Qwen2.5-Coder-3B-Q4_0.gguf -c 32768 -b 2048 -ub 512 -npp 16384 -ntg 128,128 -npl 1 -ngl 99 -fa
main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 6, n_threads_batch = 6
PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
---|---|---|---|---|---|---|---|---|---|
16384 | 128 | 1 | 16512 | 8.249 | 1986.18 | 1.377 | 92.97 | 9.626 | 1715.40 |
16384 | 128 | 1 | 16512 | 8.192 | 1999.90 | 1.378 | 92.90 | 9.570 | 1725.35 |
./bin/llama-bench -m Qwen2.5-Coder-7B-Instruct-Q4_0.gguf -p 512,4096 -pg 16384,128 -fa 0,1
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes
model | size | params | backend | ngl | fa | test | t/s |
---|---|---|---|---|---|---|---|
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | pp512 | 2973.22 ± 3.84 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | pp4096 | 2609.94 ± 0.82 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | tg128 | 116.65 ± 0.65 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | pp16384+tg128 | 1344.49 ± 9.49 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | pp512 | 2883.97 ± 2.49 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | pp4096 | 2302.82 ± 0.82 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | tg128 | 121.97 ± 0.10 |
qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | pp16384+tg128 | 1175.46 ± 0.30 |
build: 741b729 (4825)
So it seems the PR improves the TG speed in these cases.
Numbers I get on RTX 4090 look different and flash attention shows clear benefit in both prompt processing and text generation.

Master

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
build: 4299404 (4839)
main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 16, n_threads_batch = 16
main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 16, n_threads_batch = 16
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
build: 4299404 (4839)

PR

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
build: 741b729 (4825)
main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 16, n_threads_batch = 16
main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 16, n_threads_batch = 16
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
build: 741b729 (4825)
Thank you, this is a good find. In terms of correctness this looks good to me. However, I think `parallel_blocks` should be changed to a runtime parameter. I originally made it a template parameter because it was only ever set in conjunction with `cols_per_block`, but since it's only used outside of the loop over `K->ne[1]` it should be fine. The advantage would be that this would simplify your implementation and avoid the need for an increase in template specializations that would need to be compiled. I'll make a PR for this.
Consider also taking a look at the FP16 kernel; it is currently not really being used since the default FA precision is FP32 but it would likely also benefit.
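To illustrate the template-versus-runtime distinction discussed above, here is a minimal sketch; the kernel names and signatures are invented, not the actual ggml-cuda code.

```cpp
// Hypothetical sketch; kernel names and signatures are invented.

// (a) parallel_blocks as a template parameter: every supported value needs
//     its own compiled specialization.
template <int cols_per_block, int parallel_blocks>
__global__ void fattn_vec_templated(/* tensors, strides, ... */) {
    // parallel_blocks is a compile-time constant here
}

// (b) parallel_blocks as a runtime parameter: one specialization per
//     cols_per_block; the value arrives as an ordinary kernel argument.
template <int cols_per_block>
__global__ void fattn_vec_runtime(const int parallel_blocks /* , ... */) {
    // parallel_blocks is only needed for the grid decomposition, i.e.
    // outside the hot loop over K->ne[1], so a runtime value is fine
}
```

The payoff of variant (b) is fewer template instantiations to compile while leaving the inner loop untouched.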
Thanks @JohannesGaessler. I agree changing `parallel_blocks` to a runtime parameter makes sense. I'll wait for your PR and then rebase my change on top of it. I'll also update the FP16 kernel once your PR is merged. As the FP16 kernel is currently not being used, I didn't want to pay the cost of template instantiation for that kernel.
I pushed a patch that makes `parallel_blocks` a runtime parameter. With the code I pushed there would very likely be a performance regression for V100s and AMD GPUs with tensor cores; we should check the performance prior to merging. @IMbackK your assistance would be appreciated when the time comes.
@JohannesGaessler sure, I would be happy to, just inform me when it is ready.
Force-pushed from a8a7175 to d83b0d0.
Thanks @JohannesGaessler for this change. I have now rebased my change on top of your branch. I have confirmed that the performance results are the same as what I reported in the PR description on RTX 4090.

@JohannesGaessler What are the next steps here? Please let me know if I can help with some of the testing. Are there any specific GPU architectures that you would like to see tested?

The next step is testing. Sorry, I forgot. I'm currently busy but I should get to it by Friday at the latest.
Interestingly, there does appear to be a small but measurable TG improvement on ROCm gfx11.
Might be larger if there's a way to measure with the context prefilled to 30k or so.
@Beinsezii thanks for testing. Yes, the improvements are expected to be larger if we prefill the context to a large sequence length. The way I test this is as described in the PR description.
Ah, I didn't realize that's what `-pg` measures. In that case, for llama-8b-q4k my XTX is at:
Sorry for the late reply; I'm seeing bad performance on Pascal. Against my expectations, the root cause seems to be specifically the change where I made the number of parallel blocks a runtime parameter. On Ada Lovelace this is completely negligible, but on Pascal the kernel runtime for some reason goes up by ~50%. Do you have an idea what could be causing this?
I debugged the performance issues on Pascal and was able to fix them by changing the way the grid indices are mapped to tensor slices; this seems to improve Pascal performance in general. I pushed a new version.
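Purely as an illustration of what "changing the way grid indices are mapped to tensor slices" can mean, here is a sketch with invented names; the actual mapping in the pushed version may differ.

```cpp
// Hypothetical sketch: two ways to assign (head, KV slice) to grid indices.
// This is not the actual ggml-cuda code.

// Mapping A: the KV slice index varies fastest across blockIdx.x.
__global__ void fattn_slices_fastest(/* ... */) {
    const int kv_slice = blockIdx.x;  // which chunk of the KV sequence
    const int head     = blockIdx.y;  // which attention head
    // ... accumulate partial results for (head, kv_slice) ...
}

// Mapping B: heads vary fastest, KV slices vary slowest. The same total
// work is launched, but blocks are scheduled onto SMs in a different
// order, which can matter on some architectures (e.g. Pascal).
__global__ void fattn_heads_fastest(/* ... */) {
    const int head     = blockIdx.x;
    const int kv_slice = blockIdx.y;
    // ... same work as above, different scheduling order ...
}
```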
Force-pushed from d83b0d0 to b6e067b.
Thanks @JohannesGaessler. I have now rebased my changes on top of it.
Did some more testing on RTX 4090. I see around 10% improvement on the Llama-3B and Qwen-1.5B models for high sequence lengths.
Adds the following optimizations to the CUDA flash decoding code:
- Find out active blocks per SM using the cudaOccupancyMaxActiveBlocksPerMultiprocessor API. Use this value to determine the optimal parallel_blocks value.
- Prefer vector flash attention kernels over the MMA kernel for BS=1.

This results in up to 15% perf improvement in gen phase throughput for large seq lengths.

Issue: ggml-org#12182
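For context, a simplified sketch of what preferring the vector kernels over the MMA kernel for BS=1 could look like as a dispatch heuristic; the enum and function names here are invented, not the actual ggml-cuda API.

```cpp
// Hypothetical dispatch sketch; the real ggml-cuda selection logic differs.
enum class fattn_kernel { VEC_F32, VEC_F16, MMA };

static fattn_kernel select_fattn_kernel(int batch_size, bool fp16_ok) {
    if (batch_size == 1) {
        // For single-sequence decoding the vector kernels can be split
        // across more blocks, giving better occupancy than the MMA kernel.
        return fp16_ok ? fattn_kernel::VEC_F16 : fattn_kernel::VEC_F32;
    }
    return fattn_kernel::MMA;  // larger batches keep tensor cores busy
}
```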
Force-pushed from b6e067b to aa5aa01.
Rebased the changes on top of master as there were some conflicts. Also benchmarked performance on GTX 1080. This is what I get:

There are no regressions and we see a decent gain with the Qwen 1.5B model. @JohannesGaessler do you think we can merge this change?
There are performance regressions, just not for the GPUs you've tested. Your implementation does not consider tail effects, so on the last wave the GPU utilization can be bad depending on the SM count vs. the input size. This is noticeable with an RTX 3090 or a P40. Since I have those GPUs for testing, I'm working on a fix myself and will provide it soon.
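A toy calculation of the tail effect being described (all numbers made up except the 82 SMs of an RTX 3090):

```cpp
// Toy host-side calculation of last-wave utilization; illustrative only.
#include <cstdio>

int main() {
    const int sm_count          = 82;   // SMs on an RTX 3090
    const int blocks_per_sm     = 2;    // hypothetical occupancy
    const int concurrent_blocks = sm_count * blocks_per_sm;   // 164 blocks in flight

    const int total_blocks = 200;       // hypothetical grid size
    const int full_waves   = total_blocks / concurrent_blocks;                // 1
    const int tail_blocks  = total_blocks - full_waves * concurrent_blocks;   // 36

    // The last wave occupies only 36 of 164 block slots, i.e. ~22% utilization.
    std::printf("full waves: %d, tail blocks: %d, tail utilization: %.0f%%\n",
                full_waves, tail_blocks, 100.0 * tail_blocks / concurrent_blocks);
    return 0;
}
```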
I can try to get access to a 3090. Can you please paste the exact commands you are testing with?
I am testing the performance with commands like this:

export model_name=llama_3-8b && export quantization=q4_0
./bench --model models/opt/${model_name}-${quantization}.gguf -r 1 -n 0 -fa 1 -p 16384 -ub 1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384 -o sql|sqlite3 llama-bench.sqlite
python3 scripts/compare-llama-bench.py -s gpu_info,model_type,n_ubatch -i llama-bench.sqlite -b 3d652bfddfba09022525067e672c3c145c074649 -c 2a38d5c8c4843c413bb65d3724b3256c9a4973df

There are GPUs for which the code path with batch sizes > 1 changes, but for an RTX 3090 only batch size 1 is relevant. But as I said, I will push something soon. I would suggest you wait until then and see if you can improve upon that version.
Sure, thanks. I will wait for your changes to be available.
I pushed a commit to your branch which iterates over potentially higher values for `parallel_blocks`.

So on the GPUs I've tested, the performance is now consistently better than on master. @ggerganov if you could check the performance on your V100 it would be appreciated. @IMbackK please check the performance for an AMD GPU using the WMMA kernels. In both cases, the performance for batch size 1 and for batch sizes >> 1 may have changed. Also, since I have now written a non-negligible percentage of the code in this PR, I would appreciate it if someone else could add a review.
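A rough host-side sketch of the general idea of iterating over candidate `parallel_blocks` values and preferring one that fills the waves well; the selection heuristic actually pushed to the branch may weigh things differently.

```cpp
// Hypothetical helper: pick a parallel_blocks value that keeps the grid's
// waves reasonably full. Not the actual logic pushed to the branch.
static int pick_parallel_blocks(int base_blocks,       // blocks without any KV split (e.g. one per head)
                                int sm_count, int max_blocks_per_sm, int pb_max) {
    const int slots = sm_count * max_blocks_per_sm;  // blocks that can run at once
    int   best_pb  = 1;
    float best_eff = 0.0f;
    for (int pb = 1; pb <= pb_max; ++pb) {
        const int   total = base_blocks * pb;                // grid size with this split
        const int   waves = (total + slots - 1) / slots;     // rounded-up wave count
        const float eff   = (float) total / (float) (waves * slots);  // average wave fill
        if (eff > best_eff) {   // prefer the split that wastes the least of the last wave
            best_eff = eff;
            best_pb  = pb;
        }
    }
    return best_pb;
}
```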
Small 7900 XTX benches on the new head:

Seems like the additional regression fixes have almost doubled my effective increase since the last bench (#12183 (comment)). It's kind of funny: because of the difference in utilization, this PR changes the tone of my coil whine.
@JohannesGaessler I see no change or slight improvements across the board on CDNA with WMMA.

PR:

Master:
With V100 it seems there is not much difference - is this within expectations?

GGML_CUDA=1 ./scripts/compare-commits.sh master pr/12183 -m models-mnt/open-llama/7B-v2/ggml-model-q4_0.gguf -m models-mnt/open-llama/7B-v2/ggml-model-q8_0.gguf -m models-mnt/open-llama/7B-v2/ggml-model-q4_k.gguf -fa 1 -n 0 -p 4096 -ub 1,2,4,8,16,32,64,128,256,512 -r 1 -ngl 99
If there is no change that's fine, I just want to make sure there is no regression. ±1% can just be run-to-run variation.
From my end this PR would now be good to merge.
Ditto RDNA3.
I'm also happy with the code changes.
This PR adds the following optimizations to the CUDA flash decoding code:

- Find out the number of active blocks per SM using the `cudaOccupancyMaxActiveBlocksPerMultiprocessor` API. Use this value to determine the optimal `parallel_blocks` value.
- Prefer vector flash attention kernels over the MMA kernel for BS=1.

As flash decoding is dominant only with large sequence lengths, end-to-end perf improvements are seen only with larger sequence lengths. With a 16000 sequence length, we see an improvement of up to 14% in gen phase throughput. The gains are expected to continue to scale with larger sequence lengths. Details available below.
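As a rough illustration of the occupancy-based sizing, the sketch below queries the occupancy API and derives an upper bound for the split; the kernel stub and helper are placeholders, not the actual ggml-cuda code.

```cpp
// Hypothetical sketch of occupancy-driven sizing; illustrative only.
#include <cuda_runtime.h>

template <int cols_per_block>
__global__ void flash_attn_vec_stub(const int parallel_blocks) { /* ... */ }

static int max_parallel_blocks(int nheads, int nthreads, size_t smem_bytes) {
    int blocks_per_sm = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocks_per_sm, flash_attn_vec_stub<1>, nthreads, smem_bytes);

    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, /*device=*/0);

    // How many blocks the GPU can keep in flight, divided by the blocks
    // we already launch (one per head for BS=1), bounds how far the KV
    // sequence can usefully be split.
    const int in_flight = blocks_per_sm * prop.multiProcessorCount;
    return in_flight > nheads ? in_flight / nheads : 1;
}
```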
Issue: #12182
Performance
As this change impacts gen phase throughput for large seq lengths, the benchmarking was done in the following way:
llama-bench.exe -m <model> -p 10,100,1000,2000,4000,8000,16000 -n 0 -pg 10,200 -pg 100,200 -pg 1000,200 -pg 2000,200 -pg 4000,200 -pg 8000,200 -pg 16000,200 -o csv -fa 1
The `avg_ns` value is taken from the CSV output for a) the prompt phase and b) the prompt + gen phase combined. Subtracting a) from b) gives the time spent only in the gen phase, which is shown as `gen_avg_ms` in the table below.

I get the following numbers on RTX 4090 on master (prefixed with master) and this PR (prefixed with pr):
I have read the contributing guidelines
Self-reported review complexity: