
Bug: Flash Attention performs worse under ROCM #10439


Closed
Mushoz opened this issue Nov 20, 2024 · 49 comments
Labels
bug-unconfirmed medium severity Used to report medium severity bugs in llama.cpp (e.g. Malfunctioning Features but still useable)

Comments

@Mushoz

Mushoz commented Nov 20, 2024

What happened?

Turning on flash attention degrades performance under ROCm (at least it does with a 7900 XTX). Using batched bench, the degradation is quite minor at a batch size of 1.

prompt processing: 461 -> 434
token generation: 24.26 -> 23.84

However, when running multiple batches of requests at the same time, the effect is MUCH more pronounced. Especially at a batch size of 16 the difference is massive:

prompt processing: 678 -> 375
token generation: 169.65 -> 86.87

Flash Attention is needed to be able to use quantization for the KV-cache, but the performance hit is drastic. Can this be fixed?

Name and Version

build: 4123 (2eb76b2) with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu

What operating system are you seeing the problem on?

Linux

Relevant log output

No response

@Mushoz Mushoz added bug-unconfirmed medium severity Used to report medium severity bugs in llama.cpp (e.g. Malfunctioning Features but still useable) labels Nov 20, 2024
@JohannesGaessler
Collaborator

It's a known issue and is caused by the HIP port of the CUDA FlashAttention kernel for large batch sizes having extremely poor performance (so the kernel optimized for small batch sizes is used instead). With the current code this issue cannot be fixed.

@Mushoz
Author

Mushoz commented Nov 20, 2024

Are there any plans to rewrite that code to be optimized for ROCM instead of a CUDA port?

@JohannesGaessler
Collaborator

No. There currently isn't a llama.cpp/GGML dev working specifically on AMD performance or even support. I am writing a lot of CUDA code, but the extent of the effort I am willing to invest is making sure that the HIP port doesn't break and determining the comparatively best code paths for AMD.

@Mushoz
Author

Mushoz commented Nov 21, 2024

Why does the performance also regress when FA is enabled when not using batching? That's also a bit weird given the fact it's optimized for small batch sizes.

Large batch sizes haven't really been that important in the past. Hobbyists usually do not use parallelism, as they have single-user use cases most of the time. But with the introduction of speculative decoding, which will hopefully land in the llama.cpp server in the future as well, performance at larger batch sizes will become important even for single-user use cases.

What will it take for optimizations for AMD to be considered? Will it take someone to join the project to develop and maintain these optimizations? Or might a shift towards more optimization for AMD also be made when (if?) AMD starts to become more popular among the end users?

@JohannesGaessler
Collaborator

Why does the performance also regress when FA is enabled when not using batching? That's also a bit weird given the fact it's optimized for small batch sizes.

The code is optimized for small batch sizes and NVIDIA GPUs; if you use HIP to translate it for AMD you still pay a performance penalty vs. using an external BLAS library optimized for AMD.

What will it take for optimizations for AMD to be considered? Will it take someone to join the project to develop and maintain these optimizations? Or might a shift towards more optimization for AMD also be made when (if?) AMD starts to become more popular among the end users?

My personal goal with llama.cpp is to reduce the cost at which the largest models can be run at reasonable speeds. As of right now I think there simply isn't any AMD hardware that would be worth buying over second-hand NVIDIA hardware (RTX 3090/P40). I have seen some second-hand AMD GPUs such as the Mi 60 at very competitive prices on ebay but unfortunately the supply seems to be extremely limited. If AMD were to release a consumer GPU with a high VRAM capacity at a sufficiently low price I would start optimizing performance for that GPU (even though the AMD dev tools are worse or nonexistent).

If a new dev were to join the project with an interest in improving AMD performance I would be happy to assist them.

@Mushoz
Author

Mushoz commented Nov 21, 2024

Understandable, thanks for the detailed explanation! Hoping to see other devs join the project to optimize AMD performance then :)

By the way, do you think (second hand) 7900 xtx could potentially be cost competitive to second hand 3090 and 4090 GPUs? The memory bandwidth is very similar between those GPUs.

@JohannesGaessler
Collaborator

It has become better compared to 1-2 years ago but at least in my region (Germany) there currently don't really seem to be good second-hand offers for RX 7900 XTX cards.

@hjc4869
Contributor

hjc4869 commented Nov 22, 2024

You may try my forked branch that enables rocWMMA on top of the current CUDA WMMA flash attention implementation; I'm actively rebasing it onto the latest master branch for my own usage: https://github.com/hjc4869/llama.cpp

From my own testing it improves performance by quite a bit on RDNA3 at higher batch sizes, though it is still not optimal compared to equivalent NVIDIA GPUs.

Flash attention (master branch)

./llama-batched-bench -ngl 999 -m ~/models/qwen2.5-72b-iq4.gguf -fa -npl 1,8,16 -npp 512 -ntg 128 -c 10240

main: n_kv_max = 10240, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 999, n_threads = 32, n_threads_batch = 32

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---|---|---|---|---|---|---|---|---|---|
| 512 | 128 | 1 | 640 | 1.762 | 290.58 | 9.131 | 14.02 | 10.893 | 58.75 |
| 512 | 128 | 8 | 5120 | 24.958 | 164.12 | 28.862 | 35.48 | 53.820 | 95.13 |
| 512 | 128 | 16 | 10240 | 76.425 | 107.19 | 85.347 | 24.00 | 161.771 | 63.30 |

Flash attention (WMMA patched)

make GGML_HIPBLAS=1 GGML_CUDA_FA_ALL_QUANTS=1 AMDGPU_TARGETS=gfx1100,gfx1101 -j64

./llama-batched-bench -ngl 999 -m ~/models/qwen2.5-72b-iq4.gguf -fa -npl 1,8,16 -npp 512 -ntg 128 -c 10240

main: n_kv_max = 10240, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 999, n_threads = 32, n_threads_batch = 32

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---|---|---|---|---|---|---|---|---|---|
| 512 | 128 | 1 | 640 | 1.601 | 319.81 | 9.180 | 13.94 | 10.781 | 59.36 |
| 512 | 128 | 8 | 5120 | 15.654 | 261.65 | 29.136 | 35.15 | 44.790 | 114.31 |
| 512 | 128 | 16 | 10240 | 37.842 | 216.48 | 38.110 | 53.74 | 75.952 | 134.82 |

Flash attention off

./llama-batched-bench -ngl 999 -m ~/models/qwen2.5-72b-iq4.gguf -npl 1,8,16 -npp 512 -ntg 128 -c 10240

main: n_kv_max = 10240, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 999, n_threads = 32, n_threads_batch = 32

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---|---|---|---|---|---|---|---|---|---|
| 512 | 128 | 1 | 640 | 1.584 | 323.17 | 8.691 | 14.73 | 10.276 | 62.28 |
| 512 | 128 | 8 | 5120 | 15.569 | 263.09 | 26.352 | 38.86 | 41.921 | 122.13 |
| 512 | 128 | 16 | 10240 | 37.198 | 220.23 | 50.002 | 40.96 | 87.200 | 117.43 |

@Mushoz
Author

Mushoz commented Nov 22, 2024

Wow! I am seeing very good results here. Some observations:

  1. With FA turned off, performance between your branch and master is identical. So that's good, it did not introduce any regressions in the FA turned off case.
  2. With the non-batched bench, the regression observed when turning on FA for prompt processing is removed. Turning on FA gives about the same tokens/sec during PP as when it's turned off. In the master branch, a big performance regression is noticed.
  3. With the non-batched bench, the regression when turning on FA for the tokens/sec during generation is still observed, but it has not gotten worse. So identical performance to the master branch there.
  4. For the batched bench, for batch sizes of 1, 2, 4 and 8, identical performance is observed with FA turned on compared to the master branch, including the drop-off in performance when going from batch size 4 to batch size 8. So there's probably some performance to be gained versus the non-FA case here. But at least this branch did not regress compared to master.
  5. For batch sizes 16 and 32, turning on FA massively boosts performance, whereas on master turning on FA would actually REDUCE performance. Very good improvement there!
  6. All in all, I did not experience a single downside compared to master branch. Very good job!

A question: I saw that your branch also introduced the option to compile with GGML_CUDA_FA_ALL_QUANTS. What does this do? I did not enable this, but is it worth enabling?

Master branch:

Single Bench

Device 0: Radeon RX 7900 XTX, compute capability 11.0, VMM: no

| model | size | params | backend | ngl | fa | test | t/s |
|---|---|---|---|---|---|---|---|
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 0 | pp512 | 744.82 ± 2.17 |
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 0 | tg128 | 27.19 ± 0.06 |
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 1 | pp512 | 639.80 ± 0.63 |
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 1 | tg128 | 25.64 ± 0.01 |

build: a5e4759 (4150)

Batched bench FA off

main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 12, n_threads_batch = 12

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---|---|---|---|---|---|---|---|---|---|
| 384 | 128 | 1 | 512 | 0.550 | 697.74 | 4.969 | 25.76 | 5.519 | 92.77 |
| 384 | 128 | 2 | 1024 | 1.092 | 703.51 | 6.267 | 40.85 | 7.359 | 139.16 |
| 384 | 128 | 4 | 2048 | 2.200 | 698.23 | 9.225 | 55.50 | 11.425 | 179.26 |
| 384 | 128 | 8 | 4096 | 4.813 | 638.29 | 16.997 | 60.25 | 21.810 | 187.81 |
| 384 | 128 | 16 | 8192 | 11.088 | 554.12 | 17.008 | 120.41 | 28.096 | 291.57 |

ggml/src/ggml-cuda/ggml-cuda.cu:70: ROCm error
ROCm error: out of memory

Batched bench FA on

main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 12, n_threads_batch = 12

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---|---|---|---|---|---|---|---|---|---|
| 384 | 128 | 1 | 512 | 0.635 | 604.85 | 5.115 | 25.02 | 5.750 | 89.04 |
| 384 | 128 | 2 | 1024 | 1.276 | 601.72 | 6.224 | 41.13 | 7.501 | 136.52 |
| 384 | 128 | 4 | 2048 | 2.819 | 544.78 | 9.554 | 53.59 | 12.373 | 165.52 |
| 384 | 128 | 8 | 4096 | 7.176 | 428.08 | 19.747 | 51.86 | 26.923 | 152.14 |
| 384 | 128 | 16 | 8192 | 20.508 | 299.59 | 40.661 | 50.37 | 61.168 | 133.93 |
| 384 | 128 | 32 | 16384 | 66.030 | 186.10 | 84.869 | 48.26 | 150.900 | 108.58 |

Your branch

Single bench

Device 0: Radeon RX 7900 XTX, compute capability 11.0, VMM: no

| model | size | params | backend | ngl | fa | test | t/s |
|---|---|---|---|---|---|---|---|
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 0 | pp512 | 744.91 ± 1.40 |
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 0 | tg128 | 27.14 ± 0.07 |
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 1 | pp512 | 739.81 ± 1.30 |
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 1 | tg128 | 25.60 ± 0.02 |

build: 8739ed4 (4154)

Batched bench FA off

main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 12, n_threads_batch = 12

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---|---|---|---|---|---|---|---|---|---|
| 384 | 128 | 1 | 512 | 0.557 | 689.97 | 4.983 | 25.69 | 5.539 | 92.43 |
| 384 | 128 | 2 | 1024 | 1.091 | 703.82 | 6.615 | 38.70 | 7.706 | 132.88 |
| 384 | 128 | 4 | 2048 | 2.190 | 701.36 | 9.243 | 55.39 | 11.433 | 179.13 |
| 384 | 128 | 8 | 4096 | 4.795 | 640.71 | 17.012 | 60.19 | 21.807 | 187.83 |
| 384 | 128 | 16 | 8192 | 11.054 | 555.84 | 17.021 | 120.33 | 28.074 | 291.80 |

ggml/src/ggml-cuda/ggml-cuda.cu:70: ROCm error
ROCm error: out of memory
current device: 0, in function alloc at ggml/src/ggml-cuda/ggml-cuda.cu:275
ggml_cuda_device_malloc(&ptr, look_ahead_size, device)

Batched bench FA on

main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 12, n_threads_batch = 12

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---|---|---|---|---|---|---|---|---|---|
| 384 | 128 | 1 | 512 | 0.579 | 663.20 | 5.114 | 25.03 | 5.693 | 89.93 |
| 384 | 128 | 2 | 1024 | 1.100 | 698.49 | 6.243 | 41.01 | 7.342 | 139.47 |
| 384 | 128 | 4 | 2048 | 2.170 | 707.78 | 9.601 | 53.33 | 11.771 | 173.99 |
| 384 | 128 | 8 | 4096 | 4.749 | 646.81 | 19.833 | 51.63 | 24.582 | 166.62 |
| 384 | 128 | 16 | 8192 | 11.070 | 555.00 | 14.309 | 143.13 | 25.379 | 322.78 |
| 384 | 128 | 32 | 16384 | 28.232 | 435.25 | 31.475 | 130.14 | 59.707 | 274.41 |

@hjc4869
Contributor

hjc4869 commented Nov 22, 2024

GGML_CUDA_FA_ALL_QUANTS is not related to the issue. It compiles the FA kernels for all KV cache quantization combinations, for example if you want -ctk q4_0 in combination with -ctv q8_0 (different quantization types), or rarely used types like q4_1, q5_0, q5_1.
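
For illustration only (mirroring the build command earlier in this thread; the model path and GPU target are placeholders), enabling it and then running with mixed K/V cache types would look roughly like this:

    make GGML_HIPBLAS=1 GGML_CUDA_FA_ALL_QUANTS=1 AMDGPU_TARGETS=gfx1100 -j
    ./llama-server -m ./model.gguf -ngl 99 -fa -ctk q4_0 -ctv q8_0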

@Mushoz
Author

Mushoz commented Nov 22, 2024

Thanks for your explanation. Do you have any intention of trying to upstream these FA changes?

@Mushoz
Author

Mushoz commented Nov 22, 2024

Didn't that comment in that PR mention they would be open to merging improvements, as long as someone would commit to maintaining that code? Not sure if that would be an option for you.

Is there some way we can reach out to AMD and convince them to commit to performance improvements and to maintaining them for llama.cpp?

@Mushoz
Author

Mushoz commented Nov 22, 2024

That's unfortunate to hear. I really think the 7900xtx could be competitive both on cost and performance. It just needs the support to get there.

The situation has already improved massively compared to 1-2 years ago though. So perhaps in the future we'll get there.

@ccbadd

ccbadd commented Nov 22, 2024

It has become better compared to 1-2 years ago but at least in my region (Germany) there currently don't really seem to be good second-hand offers for RX 7900 XTX cards.

A new 7900 XTX is pretty much the same price as (or reasonably close to) a used 3090, so doesn't that mean AMD cards are competitive? The used NV market is going to dry up at some point, and NV does not care to make affordable cards that have enough VRAM for LLMs.

@Mushoz
Author

Mushoz commented Nov 22, 2024

A new 7900 XTX is slightly more expensive and performs slightly worse than a used 3090. With performance optimizations it would be competitive, but right now the 3090 has the advantage. It's a bit of a chicken-and-egg problem, to be honest.

@sirus20x6

@JohannesGaessler do you think you could point us to some parts of the code that could benefit the most from ROCm additions? I would be interested in learning more low-level GPU programming and would like a test project.

@JohannesGaessler
Collaborator

Generally speaking the HIP port of the CUDA code struggles the most with compute-intensive kernels, i.e. the kernels that are used for prompt processing where many tokens can be worked on in parallel. As of right now these are the mul_mat_q (large batch matrix multiplication using quantized data without dequantization to VRAM) and the FlashAttention kernels. There would be a lot of work to do for full ROCm support; I would be happy to talk with you in detail about what could be done depending on how much time and energy you're willing to commit.

@sirus20x6

sirus20x6 commented Nov 26, 2024

Generally speaking the HIP port of the CUDA code struggles the most with compute-intensive kernels, i.e. the kernels that are used for prompt processing where many tokens can be worked on in parallel. As of right now these are the mul_mat_q (large batch matrix multiplication using quantized data without dequantization to VRAM) and the FlashAttention kernels. There would be a lot of work to do for full ROCm support; I would be happy to talk with you in detail about what could be done depending on how much time and energy you're willing to commit.

I took a look at https://github.com/hjc4869/llama.cpp and it seems what should be done is to create defines for a new ggml-rocm path, then reimplement everything piece by piece in the ggml-cuda folder while using hipified CUDA as a fallback until everything is replaced. I know the wavefront size is different on AMD, and somehow I will need to benchmark the differences. Sound about right?

@Mushoz
Author

Mushoz commented Nov 27, 2024

@sirus20x6 Let me know if/when you start working on this. Would love to help test. I will create some graphs later to showcase how the 7900xtx currently scales with batchsize. That way, it might be easier to identify on which kernels to start working.

@JohannesGaessler
Collaborator

@sirus20x6 In terms of performance basically the only kernels that really matter are matrix multiplication and FlashAttention (which is basically two matrix multiplications merged with a softmax). My advice to you would be to start with ROCm kernels that can be used instead of those in mul_mat_vec_q.cu and mul_mat_vec.cu. Those are the kernels for matrix-vector multiplication and are used when generating new tokens (for mul_mat_vec.cu you should first check whether there is a suitable alternative in the BLAS library). Compared to e.g. the general matrix-matrix multiplication kernels in mul_mat_q.cuh the matrix-vector kernels are much simpler (but for the matrix-vector kernels the HIP port also works comparatively better). I am currently not familiar with ROCm but if there is another dev interested I would be willing to learn in order to do code review and help with maintenance.

Medium-term I will add training support to llama.cpp; I think a major use case will be training LoRAs on top of quantized models. The current kernels in mul_mat_q.cuh are however not suitable for the backward pass where the weight matrix is transposed. Instead I think it will be necessary to write kernels that dequantize the weights to FP16/BF16 on-the-fly and then use FP16/BF16 tensor cores. You could also use the same kernels for the non-transposed case in the forward pass, so it may make more sense to skip the int8 kernels in mul_mat_q.cuh. (For the beginning you can also just write simple kernels that dequantize the weights to FP16 and write them to VRAM in order to use a BLAS library, but then you need to use a large batch size to get a low overhead per token.)

If you are willing to write an entire ROCm GGML backend the best person to ask for help would be @slaren since he has the most in-depth knowledge of the corresponding code.
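
For anyone picking this up, here is a minimal, illustrative HIP sketch of the dequantize-on-the-fly matrix-vector pattern described above. The block layout and names are simplified assumptions for illustration, not llama.cpp's actual types or kernels:

    #include <hip/hip_runtime.h>
    #include <cstddef>
    #include <cstdint>

    // Simplified 8-bit block format: 32 weights sharing one scale
    // (loosely modelled on a q8_0-style layout, not the real ggml structs).
    struct BlockQ8 {
        float  d;       // per-block scale
        int8_t qs[32];  // quantized weights
    };

    // y = W * x with W quantized, one wavefront per output row.
    // Launch with gridDim.x = n_rows and blockDim.x = warpSize.
    __global__ void mat_vec_q8(const BlockQ8 *W, const float *x, float *y,
                               int n_rows, int n_cols) {
        const int row  = blockIdx.x;
        const int lane = threadIdx.x;
        if (row >= n_rows) return;

        const int blocks_per_row = n_cols / 32;
        const BlockQ8 *row_blocks = W + (size_t)row * blocks_per_row;

        float sum = 0.0f;
        // Each lane handles a strided subset of the row's blocks,
        // dequantizing on the fly instead of writing FP16 weights to VRAM.
        for (int b = lane; b < blocks_per_row; b += warpSize) {
            float partial = 0.0f;
            for (int i = 0; i < 32; ++i) {
                partial += (float)row_blocks[b].qs[i] * x[b * 32 + i];
            }
            sum += row_blocks[b].d * partial;
        }

        // Cross-lane reduction within the wavefront.
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            sum += __shfl_down(sum, offset);
        }
        if (lane == 0) y[row] = sum;
    }

The real kernels additionally handle many quantization formats, multiple columns per pass and fused scaling, which is presumably where most of the AMD-specific tuning would go.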

@sirus20x6

I was a little confused at first but I see they must have been renamed to mmv.cu/h and mmq.cu/h. I will take a look and see if I can make sense of everything :)

@JohannesGaessler
Collaborator

JohannesGaessler commented Nov 27, 2024

No, I just misremembered the names of the files, sorry. mmvq.cu and mmv.cu have the comparatively easy code for matrix-vector multiplication, mmq.cuh has the code for matrix-matrix multiplication.

@thamwangjun

thamwangjun commented Nov 28, 2024

I have been superficially looking at this for a while, because I have an AMD ROCm setup of 4x MI100s with Infinity Fabric between them. AMD has provided a kernel library for ROCm, and from what little I know of it, it already has an implementation of Flash Attention 2 in it: https://github.com/ROCm/composable_kernel

@thamwangjun

I have been superficially looking at this for a while, because I have an AMD ROCm setup of 4x MI100s with Infinity Fabric between them. AMD has provided a kernel library for ROCm, and from what little I know of it, it already has an implementation of Flash Attention 2 in it: https://github.com/ROCm/composable_kernel

Not sure if it is a good idea to use the C++ APIs of this library (pre-compiled) for this project?

@thamwangjun

@JohannesGaessler I am thinking of jumping onboard (because I will be fully available in the coming weeks, and I have a vested interest in this). However I feel I do have a significant knowledge gap in terms of implementing this, and I was also thinking of importing code from https://github.com/ROCm/composable_kernel, since it is an official example from AMD of going about implementing these kernels. Let me know what you think, if you are able to help with code reviews and such. Thanks.

@JohannesGaessler
Collaborator

As outlined in the coding guidelines:

Avoid adding third-party dependencies, extra files, extra headers, etc.

How many lines of imported code are we talking about? Will you be available long-term for maintenance?

@thamwangjun

@JohannesGaessler I plan on importing the bare minimum to get it working initially, and then rewriting/refactoring all of it to more appropriately suit this project and its guidelines.

I will be available for the foreseeable future (because I have a machine with 4x MI100 + Infinity Fabric, and it is very, very unlikely I will get a third machine soon). I also plan on adding documentation as I go, because I appreciate being able to easily pick up a project.

@Mushoz
Author

Mushoz commented Dec 3, 2024

@thamwangjun Since MI100 is CDNA based, will these improvements only target that architecture? Or are improvements for RDNA3 based cards (eg 7900xtx) to be expected as well?

@JohannesGaessler
Collaborator

@thamwangjun if you make a PR with a working implementation and pledge to maintain it "for the foreseeable future" I would likely be willing to merge it. However, if the implementation becomes unmaintained and broken and no one steps up to maintain it my position will be that the code should be removed again, as was done with OpenCL. Of course, if the implementation is simple, self-contained, and well-documented it will be more likely that someone will be willing to maintain it (including me).

@thamwangjun

@Mushoz I'm sorry, but at this point I will start work with only CDNA support in mind, because that is what AMD's own flash attention supports for now. But I have an old RDNA1 card and RDNA3 integrated graphics in my laptop, so who knows what the future may hold 🤔

What I am committed to right now is getting my MI100 setup running faster. I feel that the HIP port of CUDA is leaving a lot of potential performance benefits behind.

@thamwangjun

@JohannesGaessler thanks, I understand it is not reasonable to leave broken features in the project. It is better to remove them if broken/unmaintained.
I think I have work cut out for me, I will first be starting on the matrix multiplications, by referencing how AMD does it in their own kernel.

@sorasoras

@Mushoz I'm sorry, but at this point I will start work with only CDNA support in mind, because that is what AMD's own flash attention supports for now. But I have an old RDNA1 card and RDNA3 integrated graphics in my laptop, so who knows what the future may hold 🤔

What I am committed to right now is getting my MI100 setup running faster. I feel that the HIP port of CUDA is leaving a lot of potential performance benefits behind.

I think AMD's own flash attention does support the RDNA3 forward pass for now.

Also, there is an FA2 Triton implementation in a third-party repo at
https://github.com/Repeerc/flash-attention-v2-RDNA3-minimal
It supports training on RDNA3, i.e. both the forward and backward passes.

@Mushoz
Author

Mushoz commented Dec 6, 2024

This might be useful to reference while working on a solution: Dao-AILab/flash-attention#1203

This merged PR adds FA2 support for CDNA & RDNA.

@tbocek

tbocek commented Dec 14, 2024

I just tested the rocWMMA-enabled llama.cpp from https://github.com/hjc4869/llama.cpp with rocm-6.3.0. Since I use the opencl-amd-dev package on Arch, it is already included; no additional installation was needed. Here is my command to test (on my 7900 XTX): ./build/bin/llama-batched-bench --model /mnt/models/Qwen2.5-Coder-32B-Instruct-abliterated-IQ4_XS.gguf -fa -npl 1,8,16 -npp 512 -ntg 128 -c 10240 -ngl 999

ggerganov/llama.cpp

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---|---|---|---|---|---|---|---|---|---|
| 512 | 128 | 1 | 640 | 0.754 | 679.43 | 4.674 | 27.38 | 5.428 | 117.91 |
| 512 | 128 | 8 | 5120 | 10.547 | 388.34 | 14.202 | 72.10 | 24.749 | 206.87 |
| 512 | 128 | 16 | 10240 | 31.898 | 256.82 | 49.061 | 41.74 | 80.959 | 126.48 |

hjc4869/llama.cpp

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---|---|---|---|---|---|---|---|---|---|
| 512 | 128 | 1 | 640 | 0.645 | 793.55 | 4.722 | 27.11 | 5.367 | 119.24 |
| 512 | 128 | 8 | 5120 | 5.610 | 730.14 | 14.338 | 71.42 | 19.948 | 256.67 |
| 512 | 128 | 16 | 10240 | 12.379 | 661.74 | 12.777 | 160.28 | 25.157 | 407.05 |

@danielzgtg
Contributor

I am on a 6650 XT, which doesn't have WMMA. I don't know what speedup is possible, but at the bare minimum, the flash attention implementation should stop making performance worse.

@Headcrabed

@JohannesGaessler thanks, I understand it is not reasonable to leave broken features in the project. It is better to remove them if broken/unmaintained. I think I have work cut out for me, I will first be starting on the matrix multiplications, by referencing how AMD does it in their own kernel.

@thamwangjun Hello, are you still working on this?

@adelj88

adelj88 commented Feb 17, 2025

@JohannesGaessler I'd be interested in contributing to any initiative to optimise the kernels for ROCm, though I can only do so in a part-time capacity. It would be great to discuss how we could set something up that others can contribute to as well (or pick up from if I somehow disappear off the face of the planet); as well as whether it would be better to have a fork or a branch in the main repository. Let me know if email works

I see a number of other optimisation opportunities, such as using DPP intrinsics for the warp reduction operations instead of the __shfl_down calls. @thamwangjun I can also help out with the matrix multiplication - I may not have direct experience with composable_kernel, but I have some experience with writing an optimised fp16 gemm kernel that uses WMMA.
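
For context, the shuffle-based warp reduction being referred to has roughly this shape (an illustrative sketch, not the exact ggml code); a DPP-based variant would express the same cross-lane adds with AMD's data-parallel-primitive modifiers instead of generic shuffles:

    #include <hip/hip_runtime.h>

    // Illustrative wavefront sum reduction using __shfl_down.
    // A DPP-optimized version would replace these shuffles with DPP
    // row/bank shifts, which can avoid separate cross-lane permute
    // instructions.
    __device__ float warp_reduce_sum(float x) {
        for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
            x += __shfl_down(x, offset, warpSize);
        }
        return x;  // lane 0 ends up holding the full sum
    }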

@JohannesGaessler
Collaborator

I will happily explain to you what would need to be done for better ROCm performance, and I would be willing to learn ROCm in order to review the code. Contact via email is generally fine for me; right now my private email [email protected] is not functional though. You can either wait a few days until it presumably works again or you can contact me at [email protected] or [email protected].

I have some experience with writing an optimised fp16 gemm kernel that uses WMMA.

WMMA will in principle work for FlashAttention, but not having a defined memory layout really hurts optimization. For NVIDIA GPUs I defined primitives in mma.cuh that expose the mma PTX instructions instead. For MMQ, WMMA is a non-starter; without a defined memory layout the kernel is just not viable.

@thamwangjun

thamwangjun commented Feb 19, 2025

@Headcrabed @adelj88 @JohannesGaessler Sorry for the late reply. I did several tests with lmperf comparing flash attention between ROCm's provided flash-attention implementation and llama.cpp, using Llama-3.2-1B on a single MI100.

The token generation speed is roughly the same. I was hoping to use ROCm's provided flash-attention implementation as a baseline to compare against, but my findings make it hard to use ROCm's implementation as a baseline.

Can someone else help validate my findings and test the performance of ROCm's FA compared to llama.cpp? I had assumed that ROCm's FA would be the fastest since it is from AMD.

EDIT1: Add AMD FA assumption

@JohannesGaessler
Collaborator

For a batch size of 1, which is what is being used to generate tokens for a single user, the HIP port of the llama.cpp CUDA code works well. It is specifically the code for batch sizes >> 1, which is used for prompt processing, where the performance of the HIP port is bad.

@Mushoz
Author

Mushoz commented Feb 19, 2025

For a batch size of 1, which is what is being used to generate tokens for a single user, the HIP port of the llama.cpp CUDA code works well. It is specifically the code for batch sizes >> 1, which is used for prompt processing, where the performance of the HIP port is bad.

For token generation that is definitely not the case, at least not on my 7900 XTX. I am getting 27 tokens/sec on Qwen2.5-32B with the ROCm backend in llama-bench, while the Vulkan backend is giving me over 35 tokens/sec. I don't have the exact figures since it's been a while since I last checked, but I'm happy to provide those if that is useful information. But given the rather large gap between Vulkan and ROCm, at least in terms of token generation there is a lot of room for improvement.

@thamwangjun

Is it possible that Vulkan works better because RX 7900 XTX is geared towards graphics workloads? @Mushoz
I will try some tests specifically looking at prompt processing performance at > 1 batch size between ROCm's FA and llama.cpp's FA. @JohannesGaessler

@bjj
Contributor

bjj commented Feb 27, 2025

@thamwangjun BTW, I put some MI100 numbers for llama.cpp FA in #12032 after I hacked up that patch enough to make it work on CDNA. It's likely that CDNA support won't land with that patch, but if you want any specific tests I can run them for you on my Mi100

@IMbackK
Collaborator

IMbackK commented Mar 3, 2025

This issue is fixed by #12032 for CDNA and RDNA3; while the performance could still be better, it's not completely broken now. While not a universal fix, I think this thread has outlived its usefulness.

@IMbackK IMbackK closed this as completed Mar 3, 2025
@tbocek

tbocek commented Mar 5, 2025

Very nice, just ran the benchmark against current main with: ./build/bin/llama-batched-bench --model /mnt/models/Qwen2.5-Coder-32B-Instruct-abliterated-IQ4_XS.gguf -fa -npl 1,8,16 -npp 512 -ntg 128 -c 10240 -ngl 999

ggerganov/llama.cpp (Dec. 2024)

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---|---|---|---|---|---|---|---|---|---|
| 512 | 128 | 1 | 640 | 0.754 | 679.43 | 4.674 | 27.38 | 5.428 | 117.91 |
| 512 | 128 | 8 | 5120 | 10.547 | 388.34 | 14.202 | 72.10 | 24.749 | 206.87 |
| 512 | 128 | 16 | 10240 | 31.898 | 256.82 | 49.061 | 41.74 | 80.959 | 126.48 |

hjc4869/llama.cpp (Dec. 2024)

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---|---|---|---|---|---|---|---|---|---|
| 512 | 128 | 1 | 640 | 0.645 | 793.55 | 4.722 | 27.11 | 5.367 | 119.24 |
| 512 | 128 | 8 | 5120 | 5.610 | 730.14 | 14.338 | 71.42 | 19.948 | 256.67 |
| 512 | 128 | 16 | 10240 | 12.379 | 661.74 | 12.777 | 160.28 | 25.157 | 407.05 |

ggerganov/llama.cpp (Mar. 2025)

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---|---|---|---|---|---|---|---|---|---|
| 512 | 128 | 1 | 640 | 0.694 | 738.08 | 4.132 | 30.98 | 4.826 | 132.62 |
| 512 | 128 | 8 | 5120 | 5.444 | 752.36 | 9.388 | 109.08 | 14.832 | 345.20 |
| 512 | 128 | 16 | 10240 | 11.931 | 686.64 | 12.151 | 168.55 | 24.081 | 425.23 |

@serhii-nakon
Contributor

serhii-nakon commented May 24, 2025

@hjc4869 @JohannesGaessler Hello, I just tested the new build option -DGGML_HIP_ROCWMMA_FATTN=ON with an RX 7900 XTX and noticed that it speeds up inference on long contexts with FA and a quantized KV cache by tens of times.

Here are my full build options:

cmake -S . -B build -DGGML_CUDA_FA_ALL_QUANTS=ON -DGGML_HIP=ON -DGGML_HIP_ROCWMMA_FATTN=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON

Here are my command-line parameters for the llama.cpp server:

    [
      "--server",
      "--port", "8000",
      "--host", "0.0.0.0",
      "-m", "/home/jenkins/models/devstralQ4_K_M.gguf",
      "--n-gpu-layers", "99",
      "-c", "100000",
      "--jinja",
      "--cache-type-k", "q8_0",
      "--cache-type-v", "q4_0",
      "--flash-attn",
      "--mlock"
    ]

Here is the run without -DGGML_HIP_ROCWMMA_FATTN=ON:

llama-cpp-1  | slot update_slots: id  0 | task 7757 | prompt processing progress, n_past = 71810, n_tokens = 363, progress = 0.033575
llama-cpp-1  | slot update_slots: id  0 | task 7757 | prompt done, n_past = 71810, n_tokens = 363
llama-cpp-1  | slot      release: id  0 | task 7757 | stop processing: n_past = 74041, truncated = 0
llama-cpp-1  | slot print_timing: id  0 | task 7757 | 
llama-cpp-1  | prompt eval time =   91368.33 ms /  2411 tokens (   37.90 ms per token,    26.39 tokens per second)
llama-cpp-1  |        eval time =  271793.54 ms /  2232 tokens (  121.77 ms per token,     8.21 tokens per second)
llama-cpp-1  |       total time =  363161.87 ms /  4643 tokens

Here is the run with -DGGML_HIP_ROCWMMA_FATTN=ON:

llama-cpp-1  | slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 76271, n_tokens = 495, progress = 1.000000
llama-cpp-1  | slot update_slots: id  0 | task 0 | prompt done, n_past = 76271, n_tokens = 495
llama-cpp-1  | slot      release: id  0 | task 0 | stop processing: n_past = 76558, truncated = 0
llama-cpp-1  | slot print_timing: id  0 | task 0 | 
llama-cpp-1  | prompt eval time =  164965.38 ms / 76271 tokens (    2.16 ms per token,   462.35 tokens per second)
llama-cpp-1  |        eval time =   23315.83 ms /   288 tokens (   80.96 ms per token,    12.35 tokens per second)
llama-cpp-1  |       total time =  188281.21 ms / 76559 tokens

So it really is way faster.

Found a strange bug (or behavior):
I also noticed what looks like a memory leak. After nearly 22000 context tokens were processed, it started to use more and more VRAM until it crashed with an out-of-memory error.

I was also thinking it might be better to create a discussion for AMD optimizations; it would be much easier to discuss bugs and test new code there.
For example, I would test new features even if they are experimental. I use AI for my daily work (agents and the like) and I already have Docker builds from source, so it is very easy for me to test new code features even with the newest ROCm versions.

@JohannesGaessler
Collaborator

I also noticed what looks like a memory leak. After nearly 22000 context tokens were processed, it started to use more and more VRAM until it crashed with an out-of-memory error.

This is because of KV cache quantization. The FA kernel that is being used cannot use quantized data so it needs to be converted to FP16. And the size of the temporary buffer for the conversion scales with context size.
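
As a rough, illustrative back-of-envelope (the model dimensions below are made-up example values, not Devstral's actual configuration, and this is not the exact allocation logic), an FP16 copy of the K/V data grows linearly with context length:

    #include <cstdio>

    int main() {
        // Example GQA dimensions, assumed for illustration only.
        const long long n_layer   = 40;
        const long long n_head_kv = 8;
        const long long head_dim  = 128;
        for (long long n_ctx : {22000LL, 50000LL, 100000LL}) {
            // K and V tensors in FP16 (2 bytes per value).
            const long long bytes = 2 * n_layer * n_ctx * n_head_kv * head_dim * 2;
            printf("n_ctx = %6lld -> ~%.1f GiB of FP16 K/V data\n",
                   n_ctx, bytes / (1024.0 * 1024.0 * 1024.0));
        }
        return 0;
    }

The point is only the scaling: whatever the real dimensions are, a buffer proportional to the context length will eventually exhaust VRAM at long contexts.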

For example, I would test new features even if they are experimental.

Let me be frank: the bottleneck is not at all testing code. The bottleneck is writing the code.

@serhii-nakon
Contributor

serhii-nakon commented May 24, 2025

@JohannesGaessler Is it possible to reserve the temporary buffer before actual inference? I mean, for users it might be better to get an error at startup that there is not enough memory. Or at least, how can one calculate the maximum required memory in order to use the proper parameters?

@JohannesGaessler
Collaborator

Yes, it's possible. As I said, the bottleneck is people actually implementing things.
