
CUDA: Improve flash decoding kernel GPU occupancy for BS=1 case #12183


Merged: 3 commits merged into ggml-org:master on Mar 19, 2025

Conversation

@gaugarg-nv (Contributor) commented Mar 4, 2025

This PR adds the following optimizations to the CUDA flash decoding code:

  • Determine the number of active blocks per SM using the cudaOccupancyMaxActiveBlocksPerMultiprocessor API and use this value to pick the optimal parallel_blocks value (a minimal sketch of this is shown after this list).
  • Prefer the vector flash attention kernels over the MMA kernel for BS=1, since with this change the vector kernel beats the MMA-based kernels.
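
For illustration, here is a minimal sketch of the occupancy-based selection, with hypothetical helper and kernel names rather than the actual ggml-cuda code:

```cpp
// Sketch only: hypothetical names, not the actual ggml-cuda implementation.
#include <algorithm>
#include <cuda_runtime.h>

__global__ void flash_attn_vec_kernel(/* Q, K, V, dst, ... */) {
    // placeholder body; the real kernel processes one slice of the KV cache per block
}

// nblocks_base: number of blocks launched without KV splitting (e.g. one per head for BS=1).
static int choose_parallel_blocks(int n_sm, int block_size, size_t dyn_smem, int nblocks_base) {
    int max_blocks_per_sm = 0;
    // Query how many blocks of this kernel can be resident on one SM.
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks_per_sm, flash_attn_vec_kernel, block_size, dyn_smem);

    // Split the KV sequence across enough blocks to keep every SM busy;
    // the partial results of the splits are reduced in a separate pass.
    const int target_blocks = n_sm * max_blocks_per_sm;
    return std::max(1, target_blocks / nblocks_base);
}
```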

Since flash decoding dominates the runtime only at large sequence lengths, end-to-end perf improvements show up only at larger sequence lengths. With a sequence length of 16000 we see an improvement of up to 14% in gen-phase throughput, and the gains are expected to keep scaling with larger sequence lengths. Details are available below.

Issue: #12182

Performance

As this change impacts gen-phase throughput for large sequence lengths, the benchmarking was done in the following way:

  • Run llama-bench.exe -m <model> -p 10,100,1000,2000,4000,8000,16000 -n 0 -pg 10,200 -pg 100,200 -pg 1000,200 -pg 2000,200 -pg 4000,200 -pg 8000,200 -pg 16000,200 -o csv -fa 1
  • This provides the avg_ns value for (a) the prompt phase and (b) the prompt + gen phases combined. Subtracting (a) from (b) gives the time spent in the gen phase only, shown as gen_avg_ms in the table below; a worked example follows this list.
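
For example, for llama 3B at n_prompt = 16000 on master: 2371317640 ns - 959179080 ns = 1412138560 ns ≈ 1412.14 ms, which is the master_gen_avg_ms value shown in the table.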

I get the following numbers on an RTX 4090 on master (columns prefixed with master_) and on this PR (columns prefixed with pr_):

| Model | n_prompt | n_gen | master_avg_ns | master_gen_avg_ms | pr_avg_ns | pr_gen_avg_ms | Speed-up |
| ------ | -------: | ----: | ------------: | ----------------: | --------: | ------------: | -------: |
| llama 3B Q4_K - Medium | 10 | 0 | 8276500 | | 8210860 | | |
| | 100 | 0 | 11219420 | | 11776440 | | |
| | 1000 | 0 | 42959940 | | 44420860 | | |
| | 2000 | 0 | 86611920 | | 86546600 | | |
| | 4000 | 0 | 184733060 | | 180640240 | | |
| | 8000 | 0 | 398559780 | | 398832300 | | |
| | 16000 | 0 | 959179080 | | 958560000 | | |
| | 10 | 200 | 738011460 | 729.73 | 743034380 | 734.82 | 0.99 |
| | 100 | 200 | 747347360 | 736.13 | 745493700 | 733.72 | 1.00 |
| | 1000 | 200 | 827614040 | 784.65 | 813552720 | 769.13 | 1.02 |
| | 2000 | 200 | 908938600 | 822.33 | 889483700 | 802.94 | 1.02 |
| | 4000 | 200 | 1100082180 | 915.35 | 1040240060 | 859.60 | 1.06 |
| | 8000 | 200 | 1489083780 | 1090.52 | 1373256720 | 974.42 | 1.12 |
| | 16000 | 200 | 2371317640 | 1412.14 | 2192029980 | 1233.47 | 1.14 |
| llama 8B Q4_K - Medium | 10 | 0 | 11771200 | | 12052960 | | |
| | 100 | 0 | 17274700 | | 16183500 | | |
| | 1000 | 0 | 85221840 | | 85714320 | | |
| | 2000 | 0 | 175552220 | | 175143680 | | |
| | 4000 | 0 | 354803180 | | 354138120 | | |
| | 8000 | 0 | 758988000 | | 759214320 | | |
| | 16000 | 0 | 1724419420 | | 1723164020 | | |
| | 10 | 200 | 1371246100 | 1359.47 | 1333360260 | 1321.31 | 1.03 |
| | 100 | 200 | 1374515580 | 1357.24 | 1340339720 | 1324.16 | 1.02 |
| | 1000 | 200 | 1466855420 | 1381.63 | 1462192200 | 1376.48 | 1.00 |
| | 2000 | 200 | 1584856320 | 1409.30 | 1577855500 | 1402.71 | 1.00 |
| | 4000 | 200 | 1857644120 | 1502.84 | 1820829360 | 1466.69 | 1.02 |
| | 8000 | 200 | 2421499080 | 1662.51 | 2357915780 | 1598.70 | 1.04 |
| | 16000 | 200 | 3681208540 | 1956.79 | 3607611840 | 1884.45 | 1.04 |
| qwen2 1.5B Q4_K - Medium | 10 | 0 | 9674800 | | 7476440 | | |
| | 100 | 0 | 10231340 | | 10174900 | | |
| | 1000 | 0 | 31047760 | | 31299640 | | |
| | 2000 | 0 | 61244000 | | 61835680 | | |
| | 4000 | 0 | 125470380 | | 125735260 | | |
| | 8000 | 0 | 272910140 | | 272990700 | | |
| | 16000 | 0 | 654824360 | | 645511720 | | |
| | 10 | 200 | 581147620 | 571.47 | 561862220 | 554.39 | 1.03 |
| | 100 | 200 | 583839320 | 573.61 | 554249620 | 544.07 | 1.05 |
| | 1000 | 200 | 614533460 | 583.49 | 595459540 | 564.16 | 1.03 |
| | 2000 | 200 | 667034580 | 605.79 | 636267320 | 574.43 | 1.05 |
| | 4000 | 200 | 766202980 | 640.73 | 731192180 | 605.46 | 1.06 |
| | 8000 | 200 | 989000400 | 716.09 | 932145720 | 659.16 | 1.09 |
| | 16000 | 200 | 1508910500 | 854.09 | 1411803780 | 766.29 | 1.11 |
| qwen2 7B Q4_K - Medium | 10 | 0 | 11509000 | | 10902840 | | |
| | 100 | 0 | 15888700 | | 15636600 | | |
| | 1000 | 0 | 83561980 | | 83783320 | | |
| | 2000 | 0 | 170899980 | | 170541560 | | |
| | 4000 | 0 | 346676380 | | 345789520 | | |
| | 8000 | 0 | 735418880 | | 734362580 | | |
| | 16000 | 0 | 1658389120 | | 1652635280 | | |
| | 10 | 200 | 1284247020 | 1272.74 | 1277166900 | 1266.26 | 1.01 |
| | 100 | 200 | 1284649740 | 1268.76 | 1285211980 | 1269.58 | 1.00 |
| | 1000 | 200 | 1390807060 | 1307.25 | 1386056060 | 1302.27 | 1.00 |
| | 2000 | 200 | 1519021700 | 1348.12 | 1493807840 | 1323.27 | 1.02 |
| | 4000 | 200 | 1783544660 | 1436.87 | 1714694020 | 1368.90 | 1.05 |
| | 8000 | 200 | 2322356100 | 1586.94 | 2202785980 | 1468.42 | 1.08 |
| | 16000 | 200 | 3547261460 | 1888.87 | 3321806240 | 1669.17 | 1.13 |

I have read the contributing guidelines
Self-reported review complexity:

  • Low
  • Medium
  • High

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Mar 4, 2025
@gaugarg-nv gaugarg-nv force-pushed the flash_decoding_improvement branch 2 times, most recently from be39646 to 76881ac Compare March 4, 2025 16:51
@gaugarg-nv gaugarg-nv changed the title CUDA: Improve flash decoding kernel occupancy for BS=1 case CUDA: Improve flash decoding kernel GPU occupancy for BS=1 case Mar 5, 2025
@gaugarg-nv (Contributor, Author):

@JohannesGaessler @slaren @ggerganov Gentle reminder to review this once you have free cycles. This PR improves gen phase throughput significantly for large seq lengths.

@ggerganov (Member) left a comment:

Here are some tests on a V100 - not sure how relevant this GPU is, but I don't have any other NVIDIA GPU atm. It's strange that on master, enabling FA results in overall worse performance for the prompt processing, and also sometimes for the text generation (3B Qwen). Not sure if this is expected since I haven't run CUDA tests in a long time.

Master

./bin/llama-bench -m Qwen2.5-Coder-3B-Q4_0.gguf -p 512,4096 -pg 16384,128 -fa 0,1

ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes

| model | size | params | backend | ngl | fa | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | ---: | --: |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | pp512 | 5637.74 ± 9.82 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | pp4096 | 4729.83 ± 20.11 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | tg128 | 143.32 ± 0.39 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | pp16384+tg128 | 2058.38 ± 5.73 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | pp512 | 5323.62 ± 255.81 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | pp4096 | 4064.98 ± 13.52 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | tg128 | 134.82 ± 7.17 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | pp16384+tg128 | 1614.09 ± 58.57 |

build: 57b6abf (4836)

./bin/llama-batched-bench -m Qwen2.5-Coder-3B-Q4_0.gguf -c 32768 -b 2048 -ub 512 -npp 16384 -ntg 128,128 -npl 1 -ngl 99

main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 6, n_threads_batch = 6

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| -: | -: | -: | ---: | -----: | -------: | -----: | -------: | --: | ----: |
| 16384 | 128 | 1 | 16512 | 6.390 | 2564.14 | 1.746 | 73.33 | 8.135 | 2029.68 |
| 16384 | 128 | 1 | 16512 | 6.379 | 2568.26 | 1.749 | 73.18 | 8.129 | 2031.36 |

./bin/llama-batched-bench -m Qwen2.5-Coder-3B-Q4_0.gguf -c 32768 -b 2048 -ub 512 -npp 16384 -ntg 128,128 -npl 1 -ngl 99 -fa

main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 6, n_threads_batch = 6

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| -: | -: | -: | ---: | -----: | -------: | -----: | -------: | --: | ----: |
| 16384 | 128 | 1 | 16512 | 8.250 | 1985.95 | 2.076 | 61.65 | 10.326 | 1599.05 |
| 16384 | 128 | 1 | 16512 | 8.210 | 1995.55 | 2.080 | 61.54 | 10.290 | 1604.64 |

./bin/llama-bench -m Qwen2.5-Coder-7B-Instruct-Q4_0.gguf -p 512,4096 -pg 16384,128 -fa 0,1

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes

| model | size | params | backend | ngl | fa | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | ---: | --: |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | pp512 | 2978.12 ± 3.10 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | pp4096 | 2610.31 ± 1.46 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | tg128 | 117.44 ± 0.42 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | pp16384+tg128 | 1349.76 ± 0.37 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | pp512 | 2881.94 ± 3.95 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | pp4096 | 2248.26 ± 123.64 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | tg128 | 121.96 ± 0.45 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | pp16384+tg128 | 1163.20 ± 0.27 |

build: 57b6abf (4836)

PR

./bin/llama-bench -m Qwen2.5-Coder-3B-Q4_0.gguf -p 512,4096 -pg 16384,128 -fa 0,1

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes

| model | size | params | backend | ngl | fa | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | ---: | --: |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | pp512 | 5615.01 ± 16.51 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | pp4096 | 4726.22 ± 18.17 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | tg128 | 143.53 ± 0.33 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 0 | pp16384+tg128 | 2064.51 ± 0.91 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | pp512 | 5448.27 ± 13.05 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | pp4096 | 4021.94 ± 1.05 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | tg128 | 148.28 ± 0.73 |
| qwen2 3B Q4_0 | 1.70 GiB | 3.09 B | CUDA | 99 | 1 | pp16384+tg128 | 1747.92 ± 6.05 |

build: 741b729 (4825)

./bin/llama-batched-bench -m Qwen2.5-Coder-3B-Q4_0.gguf -c 32768 -b 2048 -ub 512 -npp 16384 -ntg 128,128 -npl 1 -ngl 99

main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 6, n_threads_batch = 6

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| -: | -: | -: | ---: | -----: | -------: | -----: | -------: | --: | ----: |
| 16384 | 128 | 1 | 16512 | 6.393 | 2562.87 | 1.755 | 72.95 | 8.148 | 2026.62 |
| 16384 | 128 | 1 | 16512 | 6.379 | 2568.56 | 1.751 | 73.09 | 8.130 | 2031.00 |

./bin/llama-batched-bench -m Qwen2.5-Coder-3B-Q4_0.gguf -c 32768 -b 2048 -ub 512 -npp 16384 -ntg 128,128 -npl 1 -ngl 99 -fa

main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 6, n_threads_batch = 6

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| -: | -: | -: | ---: | -----: | -------: | -----: | -------: | --: | ----: |
| 16384 | 128 | 1 | 16512 | 8.249 | 1986.18 | 1.377 | 92.97 | 9.626 | 1715.40 |
| 16384 | 128 | 1 | 16512 | 8.192 | 1999.90 | 1.378 | 92.90 | 9.570 | 1725.35 |

./bin/llama-bench -m Qwen2.5-Coder-7B-Instruct-Q4_0.gguf -p 512,4096 -pg 16384,128 -fa 0,1

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes

| model | size | params | backend | ngl | fa | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | ---: | --: |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | pp512 | 2973.22 ± 3.84 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | pp4096 | 2609.94 ± 0.82 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | tg128 | 116.65 ± 0.65 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 0 | pp16384+tg128 | 1344.49 ± 9.49 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | pp512 | 2883.97 ± 2.49 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | pp4096 | 2302.82 ± 0.82 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | tg128 | 121.97 ± 0.10 |
| qwen2 7B Q4_0 | 4.13 GiB | 7.62 B | CUDA | 99 | 1 | pp16384+tg128 | 1175.46 ± 0.30 |

build: 741b729 (4825)


So it seems the PR improves the TG speed in these cases.

@gaugarg-nv (Contributor, Author):

Here are some tests on a V100 - not sure how relevant this GPU is, but I don't have any other NVIDIA GPU atm. It's strange that on master, enabling FA results in overall worse performance for the prompt processing, and also sometimes for the text generation (3B Qwen). Not sure if this is expected since I haven't run CUDA tests in a long time.

Numbers I get on RTX 4090 look different and flash attention shows clear benefit in both prompt processing and text generation.

Master

llama-bench -m ..\qwen2.5-coder-3b-instruct-q4_0.gguf -p 512,4096 -pg 16384,128 -fa 0,1

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes

| model | size | params | backend | ngl | fa | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | ---: | --: |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 0 | pp512 | 18155.22 ± 259.17 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 0 | pp4096 | 13807.21 ± 125.12 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 0 | tg128 | 236.06 ± 3.23 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 0 | pp16384+tg128 | 4771.10 ± 6.22 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 1 | pp512 | 22103.91 ± 140.57 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 1 | pp4096 | 21581.58 ± 71.62 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 1 | tg128 | 238.21 ± 1.35 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 1 | pp16384+tg128 | 9638.49 ± 61.71 |

build: 4299404 (4839)

llama-batched-bench -m ..\qwen2.5-coder-3b-instruct-q4_0.gguf -c 32768 -b 2048 -ub 512 -npp 16384 -ntg 128,128 -npl 1 -ngl 99

main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 16, n_threads_batch = 16

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| -: | -: | -: | ---: | -----: | -------: | -----: | -------: | --: | ----: |
| 16384 | 128 | 1 | 16512 | 2.702 | 6064.05 | 0.853 | 150.09 | 3.555 | 4645.20 |
| 16384 | 128 | 1 | 16512 | 2.685 | 6101.93 | 0.843 | 151.77 | 3.528 | 4679.68 |

llama-batched-bench -m ..\qwen2.5-coder-3b-instruct-q4_0.gguf -c 32768 -b 2048 -ub 512 -npp 16384 -ntg 128,128 -npl 1 -ngl 99 -fa

main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 16, n_threads_batch = 16

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| -: | -: | -: | ---: | -----: | -------: | -----: | -------: | --: | ----: |
| 16384 | 128 | 1 | 16512 | 1.048 | 15630.32 | 0.739 | 173.32 | 1.787 | 9241.50 |
| 16384 | 128 | 1 | 16512 | 1.034 | 15840.44 | 0.739 | 173.16 | 1.774 | 9310.22 |

llama-bench -m ..\qwen2.5-coder-7b-instruct-q4_0.gguf -p 512,4096 -pg 16384,128 -fa 0,1

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes

| model | size | params | backend | ngl | fa | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | ---: | --: |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 0 | pp512 | 11633.26 ± 253.83 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 0 | pp4096 | 8653.10 ± 110.32 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 0 | tg128 | 159.24 ± 2.45 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 0 | pp16384+tg128 | 3244.89 ± 32.46 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 1 | pp512 | 12859.64 ± 25.69 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 1 | pp4096 | 12379.55 ± 54.63 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 1 | tg128 | 164.07 ± 0.27 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 1 | pp16384+tg128 | 5912.40 ± 4.79 |

build: 4299404 (4839)

PR

llama-bench -m ..\qwen2.5-coder-3b-instruct-q4_0.gguf -p 512,4096 -pg 16384,128 -fa 0,1

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes

| model | size | params | backend | ngl | fa | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | ---: | --: |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 0 | pp512 | 18447.00 ± 67.31 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 0 | pp4096 | 13804.29 ± 58.51 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 0 | tg128 | 239.25 ± 0.31 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 0 | pp16384+tg128 | 4747.66 ± 1.85 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 1 | pp512 | 21607.42 ± 47.48 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 1 | pp4096 | 21229.36 ± 100.87 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 1 | tg128 | 249.60 ± 0.71 |
| qwen2 3B Q4_0 | 1.86 GiB | 3.40 B | CUDA | 99 | 1 | pp16384+tg128 | 9649.77 ± 50.89 |

build: 741b729 (4825)

llama-batched-bench -m ..\qwen2.5-coder-3b-instruct-q4_0.gguf -c 32768 -b 2048 -ub 512 -npp 16384 -ntg 128,128 -npl 1 -ngl 99

main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 16, n_threads_batch = 16

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| -: | -: | -: | ---: | -----: | -------: | -----: | -------: | --: | ----: |
| 16384 | 128 | 1 | 16512 | 2.710 | 6045.87 | 0.849 | 150.79 | 3.559 | 4639.74 |
| 16384 | 128 | 1 | 16512 | 2.696 | 6077.46 | 0.844 | 151.59 | 3.540 | 4664.10 |

llama-batched-bench -m ..\qwen2.5-coder-3b-instruct-q4_0.gguf -c 32768 -b 2048 -ub 512 -npp 16384 -ntg 128,128 -npl 1 -ngl 99 -fa

main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 16, n_threads_batch = 16

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| -: | -: | -: | ---: | -----: | -------: | -----: | -------: | --: | ----: |
| 16384 | 128 | 1 | 16512 | 1.049 | 15616.24 | 0.745 | 171.91 | 1.794 | 9205.29 |
| 16384 | 128 | 1 | 16512 | 1.049 | 15613.28 | 0.726 | 176.30 | 1.775 | 9300.50 |

llama-bench -m ..\qwen2.5-coder-7b-instruct-q4_0.gguf -p 512,4096 -pg 16384,128 -fa 0,1

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes

| model | size | params | backend | ngl | fa | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | ---: | --: |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 0 | pp512 | 11347.19 ± 207.26 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 0 | pp4096 | 8711.01 ± 9.87 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 0 | tg128 | 161.06 ± 0.16 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 0 | pp16384+tg128 | 3296.54 ± 2.66 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 1 | pp512 | 12862.87 ± 75.49 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 1 | pp4096 | 12379.35 ± 82.83 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 1 | tg128 | 165.19 ± 0.41 |
| qwen2 7B Q4_0 | 4.12 GiB | 7.62 B | CUDA | 99 | 1 | pp16384+tg128 | 6254.35 ± 4.65 |

build: 741b729 (4825)

@JohannesGaessler (Collaborator) left a comment:

Thank you, this is a good find. In terms of correctness this looks good to me. However, I think parallel_blocks should be changed to a runtime parameter. I originally made it a template parameter because it was only ever set in conjunction with cols_per_block but since it's only used outside of the loop over K->ne[1] it should be fine. The advantage would be that this would simplify your implementation and avoid the need for an increase in template specializations that would need to be compiled. I'll make a PR for this.
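
For illustration, a minimal sketch of that distinction, with hypothetical signatures rather than the actual fattn code:

```cpp
// Sketch only: hypothetical signatures, not the real ggml-cuda fattn kernels.

// As a template parameter, every parallel_blocks value is a separate kernel
// that has to be instantiated and compiled ahead of time:
template <int cols_per_block, int parallel_blocks>
__global__ void fattn_vec_template(const float * Q, const float * K, const float * V, float * dst) {
    // ...
}

// As a runtime parameter, a single compiled kernel covers all values and the
// launcher can pick parallel_blocks at launch time (e.g. from the occupancy
// query), since it is only used outside the loop over the KV dimension:
template <int cols_per_block>
__global__ void fattn_vec_runtime(const float * Q, const float * K, const float * V, float * dst,
                                  const int parallel_blocks) {
    // ...
}
```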

Consider also taking a look at the FP16 kernel; it is currently not really being used since the default FA precision is FP32 but it would likely also benefit.

@gaugarg-nv (Contributor, Author):

Thank you, this is a good find. In terms of correctness this looks good to me. However, I think parallel_blocks should be changed to a runtime parameter. I originally made it a template parameter because it was only ever set in conjunction with cols_per_block but since it's only used outside of the loop over K->ne[1] it should be fine. The advantage would be that this would simplify your implementation and avoid the need for an increase in template specializations that would need to be compiled. I'll make a PR for this.

Consider also taking a look at the FP16 kernel; it is currently not really being used since the default FA precision is FP32 but it would likely also benefit.

Thanks @JohannesGaessler. I agree that changing parallel_blocks to a runtime parameter makes sense, and I was thinking the same.
I didn't do it as part of this PR because the different attention kernels all share the same function pointer definition of fattn_kernel_t, so it would require changes in all the attention kernels. Thanks for offering to make a PR for this.

I'll wait for your PR and then rebase my change on top of it. I'll also update the FP16 kernel once your PR is merged; since the FP16 kernel is currently not being used, I didn't want to pay the template-instantiation cost for it.

@JohannesGaessler (Collaborator):

I pushed a patch to jg/cuda-fa-np-runtime. In the function ggml_cuda_launch_fattn it should be possible to define parallel_blocks in a basically arbitrary way at runtime without affecting program correctness. Ideally it is possible to do the implementation in a generic way that does not result in a performance regression for any combination of tensor shapes and hardware - though since the previously used heuristic was rather crude I think there's a high chance this can be done.

With the code I pushed there would very likely be a performance regression for V100s and AMD GPUs with tensor cores; we should check the performance prior to merging. @IMbackK your assistance would be appreciated when the time comes.

@IMbackK (Collaborator) commented Mar 7, 2025

@JohannesGaessler sure, I would be happy to, just inform me when it is ready.

@gaugarg-nv gaugarg-nv force-pushed the flash_decoding_improvement branch from a8a7175 to d83b0d0 Compare March 7, 2025 10:06
@github-actions github-actions bot added the testing Everything test related label Mar 7, 2025
@gaugarg-nv (Contributor, Author) commented Mar 7, 2025

I pushed a patch to jg/cuda-fa-np-runtime. In the function ggml_cuda_launch_fattn it should be possible to define parallel_blocks in a basically arbitrary way at runtime without affecting program correctness. Ideally it is possible to do the implementation in a generic way that does not result in a performance regression for any combination of tensor shapes and hardware - though since the previously used heuristic was rather crude I think there's a high chance this can be done.

With the code I pushed there would very likely be a performance regression for V100s and AMD GPUs with tensor cores; we should check the performance prior to merging. @IMbackK your assistance would be appreciated when the time comes.

Thanks @JohannesGaessler for this change. I have now rebased my change on top of your branch and confirmed that the performance results on the RTX 4090 are the same as what I reported in the PR description.

@gaugarg-nv (Contributor, Author):

@JohannesGaessler What are the next steps here? Please let me know if I can help with some of the testing. Are there any specific GPU architectures that you would like to see tested?

@JohannesGaessler (Collaborator):

The next step is testing. Sorry, I forgot. I'm currently busy but I should get to it by Friday at the latest.

@Beinsezii (Contributor):

Interestingly there does appear to be a small but measurable TG improvement on ROCm gfx11

bin/llama-bench -m Meta-Llama-3.1-8B-Instruct-F16.gguf -ngl 99 -p 8192 -n 1024 -r 2 -fa 0,1
master

| model | size | params | backend | ngl | fa | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | ---: | --: |
| llama 8B F16 | 14.96 GiB | 8.03 B | ROCm | 99 | 0 | pp8192 | 1824.57 ± 3.13 |
| llama 8B F16 | 14.96 GiB | 8.03 B | ROCm | 99 | 0 | tg1024 | 47.36 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | ROCm | 99 | 1 | pp8192 | 2749.14 ± 0.03 |
| llama 8B F16 | 14.96 GiB | 8.03 B | ROCm | 99 | 1 | tg1024 | 47.64 ± 0.00 |

gaugarg-nv/flash_decoding_improvement

| model | size | params | backend | ngl | fa | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | ---: | --: |
| llama 8B F16 | 14.96 GiB | 8.03 B | ROCm | 99 | 0 | pp8192 | 1829.28 ± 0.76 |
| llama 8B F16 | 14.96 GiB | 8.03 B | ROCm | 99 | 0 | tg1024 | 47.41 ± 0.01 |
| llama 8B F16 | 14.96 GiB | 8.03 B | ROCm | 99 | 1 | pp8192 | 2728.66 ± 0.57 |
| llama 8B F16 | 14.96 GiB | 8.03 B | ROCm | 99 | 1 | tg1024 | 48.05 ± 0.01 |

Might be larger if there's a way to measure with context prefilled to 30k or so.

@gaugarg-nv (Contributor, Author) commented Mar 13, 2025

Might be larger if there's a way to measure with context prefilled to 30k or so.

@Beinsezii thanks for testing. Yes, the improvements are expected to be larger if the context is prefilled to a large sequence length.

The way I test this is as described in the PR description:

  • Run llama-bench.exe -m <model> -p 10,100,1000,2000,4000,8000,16000 -n 0 -pg 10,200 -pg 100,200 -pg 1000,200 -pg 2000,200 -pg 4000,200 -pg 8000,200 -pg 16000,200 -o csv -fa 1
  • This keeps the OSL constant at 200 but varies the ISL from small to large sequence lengths.
  • This provides the avg_ns value for (a) the prompt phase and (b) the prompt + gen phases combined. Subtracting (a) from (b) gives the time spent in the gen phase only for varying ISL.

@Beinsezii (Contributor) commented Mar 13, 2025

  • This keeps the OSL constant at 200 but varies the ISL from small to large sequence lengths.

Ah, I didn't realize that's what -pg meant.

In that case, for llama-8b-q4k my XTX is at ≈39.5 tg t/s @ 16000pp+200tg on master, and ≈52.3 t/s on this PR, which at ≈+32% is way more substantial than I was expecting. test-backend-ops is all green too, so I'm looking forward to the merge.

@JohannesGaessler (Collaborator):

Sorry for the late reply, I'm seeing bad performance on Pascal. Against my expectation the root cause seems to be specifically the change where I made the number of parallel blocks a runtime parameter. On Ada Lovelace this is completely negligible but on Pascal the kernel runtime for some reason goes up by ~50%. Do you have an idea what could be causing this?

@JohannesGaessler (Collaborator):

I debugged the performance issues on Pascal and was able to fix them by changing the way the grid indices are mapped to tensor slices; this seems to improve Pascal performance in general. I pushed a new version to jg/cuda-fa-np-runtime. After you rebase on that version we can do final performance testing and assuming there are no regressions this PR will then be good to merge from my end.

@gaugarg-nv gaugarg-nv force-pushed the flash_decoding_improvement branch from d83b0d0 to b6e067b Compare March 16, 2025 15:22
@gaugarg-nv (Contributor, Author):

I debugged the performance issues on Pascal and was able to fix them by changing the way the grid indices are mapped to tensor slices; this seems to improve Pascal performance in general. I pushed a new version to jg/cuda-fa-np-runtime. After you rebase on that version we can do final performance testing and assuming there are no regressions this PR will then be good to merge from my end.

Thanks @JohannesGaessler . I have now rebased my changes on top of jg/cuda-fa-np-runtime.

@gaugarg-nv (Contributor, Author):

Did some more testing on the RTX 4090. I see around a 10% improvement on the Llama 3B and Qwen 1.5B models at high sequence lengths.

| Model | ISL | OSL | Master: gen phase tok/sec | PR: gen phase tok/sec | Speed-up |
| ----- | --: | --: | ------------------------: | --------------------: | -------: |
| llama 3B Q4_K - Medium | 10 | 200 | 269.8302524 | 266.2338712 | 0.986672 |
| | 100 | 200 | 268.4936757 | 263.709292 | 0.982181 |
| | 1000 | 200 | 252.4435713 | 254.9205103 | 1.009812 |
| | 10000 | 200 | 170.5515393 | 189.4690656 | 1.11092 |
| qwen2 1.5B Q4_K - Medium | 10 | 200 | 344.4271769 | 360.4276128 | 1.046455 |
| | 100 | 200 | 342.1115136 | 360.8377528 | 1.054737 |
| | 1000 | 200 | 335.4593024 | 353.6292346 | 1.054164 |
| | 10000 | 200 | 264.7008325 | 288.7548641 | 1.090873 |

JohannesGaessler and others added 2 commits March 19, 2025 17:00
Adds the following optimizations to the CUDA flash decoding code:

- Find out active blocks per SM using cudaOccupancyMaxActiveBlocksPerMultiprocessor API. Use this value to determine the optimal parallel_blocks value.
- Prefer vector flash attention kernels over MMA kernel for BS=1

This results in up to 15% perf improvement in gen phase throughput for large seq lengths.

Issue: ggml-org#12182
@gaugarg-nv gaugarg-nv force-pushed the flash_decoding_improvement branch from b6e067b to aa5aa01 Compare March 19, 2025 11:38
@gaugarg-nv (Contributor, Author):

Rebased the changes on top of master as there were some conflicts.

Also benchmarked performance on GTX 1080. This is what I get:

| Model | ISL | OSL | Master: gen phase tok/sec | PR: gen phase tok/sec | Speed-up |
| ----- | --: | --: | ------------------------: | --------------------: | -------: |
| llama 3B Q4_K - Medium | 100 | 100 | 64.40901405 | 66.44957568 | 1.031681 |
| | 1000 | 100 | 57.97212522 | 60.09357042 | 1.036594 |
| | 4000 | 100 | 48.94937194 | 49.50312464 | 1.011313 |
| qwen2 1.5B Q4_K - Medium | 100 | 100 | 82.36570957 | 84.50382656 | 1.025959 |
| | 1000 | 100 | 79.3190166 | 79.74659464 | 1.005391 |
| | 4000 | 100 | 68.77001876 | 75.35970213 | 1.095822 |

There are no regressions, and we see a decent gain with the Qwen 1.5B model.

@JohannesGaessler do you think we can merge this change?

@JohannesGaessler (Collaborator):

There are performance regressions, just not for the GPUs you've tested. Your implementation does not consider tail effects so on the last wave the GPU utilization can be bad depending on the SM count vs. the input size. This is noticeable with an RTX 3090 or a P40. Since I have those GPUs for testing I'm working on a fix myself and will provide it soon.
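
For context, the kind of tail-effect-aware selection being described can be sketched roughly as follows (illustrative only, with hypothetical names; this is not the actual fix):

```cpp
// Sketch only: score candidate parallel_blocks values by how well the grid
// fills whole waves of blocks, so the last wave is not mostly idle.
static int choose_parallel_blocks_tail_aware(int n_sm, int max_blocks_per_sm,
                                             int nblocks_base, int pb_max) {
    const int wave_size = n_sm * max_blocks_per_sm; // blocks that can run concurrently
    int   best_pb  = 1;
    float best_eff = 0.0f;
    for (int pb = 1; pb <= pb_max; ++pb) {
        const int total_blocks = nblocks_base * pb;
        const int n_waves      = (total_blocks + wave_size - 1) / wave_size; // ceiling division
        // Fraction of the launched wave capacity that does useful work.
        const float efficiency = (float) total_blocks / (float) (n_waves * wave_size);
        if (efficiency > best_eff) { // strict '>' keeps the smallest pb on ties
            best_eff = efficiency;
            best_pb  = pb;
        }
    }
    return best_pb;
}
```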

@gaugarg-nv (Contributor, Author):

There are performance regressions, just not for the GPUs you've tested. Your implementation does not consider tail effects so on the last wave the GPU utilization can be bad depending on the SM count vs. the input size. This is noticeable with an RTX 3090 or a P40. Since I have those GPUs for testing I'm working on a fix myself and will provide it soon.

I can try to get access to a 3090. Can you please paste the exact llama-bench command showing the regression?

@JohannesGaessler (Collaborator):

I am testing the performance with commands like this:

export model_name=llama_3-8b && export quantization=q4_0
./bench --model models/opt/${model_name}-${quantization}.gguf -r 1 -n 0 -fa 1 -p 16384 -ub 1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384 -o sql|sqlite3 llama-bench.sqlite
python3 scripts/compare-llama-bench.py -s gpu_info,model_type,n_ubatch -i llama-bench.sqlite -b 3d652bfddfba09022525067e672c3c145c074649 -c 2a38d5c8c4843c413bb65d3724b3256c9a4973df

There are GPUs for which the code path with batch sizes > 1 changes but for an RTX 3090 only batch size 1 is relevant. But as I said, I will push something soon. I would suggest you wait until then and see if you can improve upon that version.

@gaugarg-nv (Contributor, Author):

But as I said, I will push something soon. I would suggest you wait until then and see if you can improve upon that version.

Sure, thanks. I will wait for your changes to be available.

@JohannesGaessler (Collaborator):

I pushed a commit to your branch which iterates over potential higher values for parallel_blocks and tries to reduce inefficiencies from tail effects. Also for Ampere and older the mma kernels are still used if applicable since I found them to perform better than the vector kernels even with the increased occupancy. I get the following performance:

| GPU | Model | Microbatch size | Test | t/s 0fd8487 | t/s 66d873b | Speedup |
| --- | ----- | --------------: | ---: | ----------: | ----------: | ------: |
| P40 | llama 8B Q4_0 | 1 | pp16384 | 36.54 | 40.49 | 1.11 |
| P40 | llama 8B Q4_0 | 2 | pp16384 | 47.44 | 52.33 | 1.10 |
| P40 | llama 8B Q4_0 | 4 | pp16384 | 58.82 | 68.05 | 1.16 |
| P40 | llama 8B Q4_0 | 8 | pp16384 | 73.07 | 81.42 | 1.11 |
| P40 | llama 8B Q4_0 | 16 | pp16384 | 248.38 | 264.03 | 1.06 |
| P40 | llama 8B Q4_0 | 32 | pp16384 | 337.84 | 360.44 | 1.07 |
| P40 | llama 8B Q4_0 | 64 | pp16384 | 307.31 | 396.61 | 1.29 |
| P40 | llama 8B Q4_0 | 128 | pp16384 | 377.11 | 420.65 | 1.12 |
| P40 | llama 8B Q4_0 | 256 | pp16384 | 421.12 | 440.49 | 1.05 |
| P40 | llama 8B Q4_0 | 512 | pp16384 | 443.85 | 448.23 | 1.01 |
| P40 | llama 8B Q4_0 | 1024 | pp16384 | 438.29 | 444.09 | 1.01 |
| P40 | llama 8B Q4_0 | 2048 | pp16384 | 427.09 | 435.72 | 1.02 |
| P40 | llama 8B Q4_0 | 4096 | pp16384 | 428.06 | 433.74 | 1.01 |
| P40 | llama 8B Q4_0 | 8192 | pp16384 | 426.87 | 433.55 | 1.02 |
| P40 | llama 8B Q4_0 | 16384 | pp16384 | 427.76 | 435.24 | 1.02 |
| RTX 3090 | llama 8B Q4_0 | 1 | pp16384 | 125.40 | 125.37 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 1 | pp16384 | 146.66 | 151.44 | 1.03 |

So on the GPUs I've tested the performance for is now consistently better than on master. @ggerganov if you could check the performance on your V100 it would be appreciated. @IMbackK please check the performance for an AMD GPU using the WMMA kernels. In both cases both the performance for batch size 1 and batch sizes >> 1 may have changed. Also since I have now written a non-negligible percentage of the code in this PR I would appreciate it if someone else could add a review.

@Beinsezii (Contributor) commented Mar 19, 2025

Small 7900 XTX benches on new head using GGML_HIP_ROCWMMA_FATTN=1

bin/llama-bench -m Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -ngl 99 -p 16000 -n 0 -pg 16000,200 -r 3 -fa 1

| head | n_prompt | n_gen | pp_ts | tg_ts |
| ---- | -------: | ----: | ----: | ----: |
| master | 16000 | 200 | 1981.93 | 39.48 |
| 66d873b | 16000 | 200 | 1994.49 | 61.57 |

Seems like the additional regression fixes have almost doubled my effective increase since the last bench #12183 (comment)

It's kind of funny: because of the difference in utilization, this PR changes the tone of my coil whine.

@IMbackK (Collaborator) commented Mar 19, 2025

@JohannesGaessler I see no change or slight improvements across the board on CDNA with WMMA.

PR:

  Device 0: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: no, Wave Size: 64
| model                          |       size |     params | backend    | ngl | n_ubatch | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |        1 |  1 |        pp1024 |         88.55 ± 5.33 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |       16 |  1 |        pp1024 |        201.82 ± 0.12 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      256 |  1 |        pp1024 |       2364.47 ± 2.46 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      512 |  1 |        pp1024 |       3027.19 ± 2.14 |

Master:

  Device 0: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: no, Wave Size: 64
| model                          |       size |     params | backend    | ngl | n_ubatch | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |        1 |  1 |        pp1024 |         85.60 ± 5.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |       16 |  1 |        pp1024 |        200.71 ± 0.03 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      256 |  1 |        pp1024 |       2368.71 ± 1.82 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      512 |  1 |        pp1024 |       3043.48 ± 3.30 |

@ggerganov (Member):

With V100 it seems there is not much difference - is this within expectations?

GGML_CUDA=1 ./scripts/compare-commits.sh master pr/12183 -m models-mnt/open-llama/7B-v2/ggml-model-q4_0.gguf -m models-mnt/open-llama/7B-v2/ggml-model-q8_0.gguf -m models-mnt/open-llama/7B-v2/ggml-model-q4_k.gguf -fa 1 -n 0 -p 4096 -ub 1,2,4,8,16,32,64,128,256,512 -r 1 -ngl 99
| Model | Microbatch size | Test | t/s master | t/s pr/12183 | Speedup |
| ----- | --------------: | ---: | ---------: | -----------: | ------: |
| llama 7B Q4_0 | 1 | pp4096 | 120.11 | 119.08 | 0.99 |
| llama 7B Q4_0 | 2 | pp4096 | 207.93 | 211.72 | 1.02 |
| llama 7B Q4_0 | 4 | pp4096 | 313.78 | 311.21 | 0.99 |
| llama 7B Q4_0 | 8 | pp4096 | 487.66 | 493.68 | 1.01 |
| llama 7B Q4_0 | 16 | pp4096 | 756.47 | 763.94 | 1.01 |
| llama 7B Q4_0 | 32 | pp4096 | 1111.01 | 1114.67 | 1.00 |
| llama 7B Q4_0 | 64 | pp4096 | 613.40 | 607.49 | 0.99 |
| llama 7B Q4_0 | 128 | pp4096 | 1123.29 | 1102.10 | 0.98 |
| llama 7B Q4_0 | 256 | pp4096 | 1668.56 | 1672.35 | 1.00 |
| llama 7B Q4_0 | 512 | pp4096 | 2138.19 | 2157.27 | 1.01 |
| llama 7B Q4_K_M | 1 | pp4096 | 112.75 | 112.06 | 0.99 |
| llama 7B Q4_K_M | 2 | pp4096 | 177.46 | 177.55 | 1.00 |
| llama 7B Q4_K_M | 4 | pp4096 | 258.02 | 259.64 | 1.01 |
| llama 7B Q4_K_M | 8 | pp4096 | 310.70 | 312.40 | 1.01 |
| llama 7B Q4_K_M | 16 | pp4096 | 690.29 | 693.88 | 1.01 |
| llama 7B Q4_K_M | 32 | pp4096 | 998.44 | 995.80 | 1.00 |
| llama 7B Q4_K_M | 64 | pp4096 | 617.86 | 611.84 | 0.99 |
| llama 7B Q4_K_M | 128 | pp4096 | 1131.27 | 1111.56 | 0.98 |
| llama 7B Q4_K_M | 256 | pp4096 | 1679.40 | 1679.73 | 1.00 |
| llama 7B Q4_K_M | 512 | pp4096 | 2144.94 | 2161.94 | 1.01 |
| llama 7B Q8_0 | 1 | pp4096 | 83.94 | 83.60 | 1.00 |
| llama 7B Q8_0 | 2 | pp4096 | 148.76 | 150.36 | 1.01 |
| llama 7B Q8_0 | 4 | pp4096 | 234.13 | 236.20 | 1.01 |
| llama 7B Q8_0 | 8 | pp4096 | 382.40 | 385.34 | 1.01 |
| llama 7B Q8_0 | 16 | pp4096 | 642.34 | 644.91 | 1.00 |
| llama 7B Q8_0 | 32 | pp4096 | 947.97 | 949.16 | 1.00 |
| llama 7B Q8_0 | 64 | pp4096 | 817.02 | 809.33 | 0.99 |
| llama 7B Q8_0 | 128 | pp4096 | 1457.01 | 1426.50 | 0.98 |
| llama 7B Q8_0 | 256 | pp4096 | 2008.42 | 2011.89 | 1.00 |
| llama 7B Q8_0 | 512 | pp4096 | 2400.06 | 2426.38 | 1.01 |

@JohannesGaessler (Collaborator):

If there is no change that's fine, I just want to make sure there is no regression. +-1% can just be run-to-run variation.

@JohannesGaessler (Collaborator) left a comment:

From my end this PR would now be good to merge.

@IMbackK (Collaborator):

IMbackK commented Mar 19, 2025

./bin/test-backend-ops -o FLASH_ATTN_EXT is also fine on CDNA so no correctness issues either.

@Beinsezii (Contributor):

./bin/test-backend-ops -o FLASH_ATTN_EXT is also fine on CDNA so no correctness issues either.

ditto rdna3

@IMbackK (Collaborator) left a comment:

I'm also happy with the code changes.

@IMbackK IMbackK merged commit 517b5dd into ggml-org:master Mar 19, 2025
51 of 52 checks passed
@gaugarg-nv gaugarg-nv deleted the flash_decoding_improvement branch March 20, 2025 07:15