Add GGML_HIP_ROCWMMA_FATTN to enable rocWMMA for FlashAttention #12032

Merged: 14 commits, Mar 3, 2025

Conversation

hjc4869
Contributor

hjc4869 commented Feb 22, 2025

  • Add a new option GGML_HIP_ROCWMMA_FATTN and defaults to OFF
  • Check for rocWMMA header availability when GGML_HIP_ROCWMMA_FATTN is enabled
  • Define FP16_MMA_AVAILABLE when GGML_HIP_ROCWMMA_FATTN is enabled and target is supported by rocWMMA (CDNA / RDNA3)
  • Use rocWMMA in FlashAttention kernel when possible

Related issue: #10439
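
For context, the first three bullets amount to preprocessor gating roughly along these lines (an illustrative sketch only; ROCWMMA_TARGET_SUPPORTED is a placeholder for ggml's actual CDNA/RDNA3 architecture check, not a real macro):

// Sketch: only advertise FP16 MMA (and thus the rocWMMA FlashAttention path)
// when the new CMake option is on and the target architecture supports rocWMMA.
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(ROCWMMA_TARGET_SUPPORTED)
#define FP16_MMA_AVAILABLE
#endif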

The github-actions bot added the labels "Nvidia GPU" (issues specific to Nvidia GPUs) and "ggml" (changes relating to the ggml tensor library for machine learning) on Feb 22, 2025.
@hjc4869
Contributor Author

hjc4869 commented Feb 22, 2025

Adding @IMbackK for review

@JohannesGaessler
Collaborator

As of right now I do not have the hardware necessary to test this code. Is anyone pledging to maintain it long-term if it gets merged?

@hjc4869
Contributor Author

hjc4869 commented Feb 23, 2025

As of right now I do not have the hardware necessary to test this code. Is anyone pledging to maintain it long-term if it gets merged?

@IMbackK mentioned in https://github.com/hjc4869/llama.cpp/pull/1 that he'll take ownership of this implementation and maintain it here. That's why I'm pinging him in this thread.

@Headcrabed

@JohannesGaessler So will we keep this code path together with @adelj88 and @thamwangjun's optimized code path in the future?

Comment on lines 76 to 88
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#else
typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::row_major> frag_a_K;
typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::col_major> frag_a_V;
typedef rocwmma::fragment<rocwmma::matrix_b, frag_m, frag_n, 16, half, rocwmma::col_major> frag_b;
typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
Collaborator


Is there any reason to do it like this and not with something like using namespace nvcuda::wmma?

Contributor Author


Switched to namespace alias instead of these ifdefs. Does that look good to you?
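
For readers following the thread, the alias approach referred to here looks roughly like the sketch below (illustrative, not necessarily the exact code in the PR; frag_m, frag_n and half come from the surrounding kernel code):

// Pick the vendor WMMA namespace once, then declare the fragment typedefs against it.
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
namespace wmma = rocwmma;
#else
namespace wmma = nvcuda::wmma;
#endif

// A single set of typedefs then replaces the duplicated blocks above, e.g.:
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;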

@IMbackK
Collaborator

IMbackK commented Feb 23, 2025

I am currently traveling and will take a look at this next week, thank you!

@Beinsezii
Contributor

Beinsezii commented Feb 25, 2025

Nice, +41% over master for pp2048 on gfx1100. Shame it doesn't affect text gen, where fa=1 is still ≥6% behind fa=0, and 22% behind radv

@Headcrabed

Nice, +41% over master for pp2048 on gfx1100. Shame it doesn't affect text gen, where fa=1 is still ≥6% behind fa=0, and 22% behind radv

@Beinsezii Maybe we need a newer version of rocWMMA and ROCm to make it work on gfx12?


@hjc4869
Contributor Author

hjc4869 commented Feb 25, 2025

gfx12 requires rocWMMA from ROCm 6.4 to compile and work; the one from the latest 6.3.x release is broken. I may send another PR to enable it after ROCm 6.4 is released.

@bjj
Contributor

bjj commented Feb 25, 2025

@hjc4869 if (NOT ${FOUND_ROCWMMA}) would need to be if (NOT FOUND_ROCWMMA) to make that test work: once the variable is expanded, the condition sees a string, and the string is truthy.

However, CHECK_INCLUDE_FILE_CXX does not really work on rocwmma/rocwmma.hpp, because the check needs an include path (you can add "-I${ROCM_PATH}/include" to it) but also other headers included first (at least stdint.h and some other ROCm headers). So right now the patch relies on skipping the fatal error, even though CMake prints a line about not finding the header.

@bjj
Contributor

bjj commented Feb 25, 2025

With this patch, on gfx908 I see prompt performance with FA=1 almost exactly equal to FA=0 (without the patch, FA=1 is both slower and scales worse). Token generation is somewhere between unpatched FA=0 and FA=1.

Unfortunately, in some situations with long context, I get garbage output. Repro:

# create a sample input file
yes hello | head -n 10000 | tr '\012' ' ' > bighello.txt
# generate without FA (hit return at >)
./build/bin/llama-cli -ngl 99 --model ~/llamaq8.gguf -c 65536 --file bighello.txt
# generate with FA (hit return at >, interrupt whenever)
./build/bin/llama-cli -ngl 99 --model ~/llamaq8.gguf -c 65536 --file bighello.txt -fa

The model is from ollama llama3.1:8b-instruct-q8_0.

The FA=0 case (or the FA=1 unpatched case) prints something like

I'm happy to chat with you! However, I notice that your message is a long series of "hello"s. Is there something specific you'd like to talk about or ask? I'm here to help with any questions or topics you'd like to discuss.

While FA=1 patched (seemingly regardless of input content) produces:

GGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG

Let me know if there are any other tests I can run.

HIP version: 6.3.42134-a9a80e791

@JohannesGaessler
Collaborator

Let me know if there are any other tests I can run.

Check the output of test-backend-ops -o FLASH_ATTN_EXT.

@bjj
Contributor

bjj commented Feb 25, 2025

Here's the output:

~/hjc4869-llama.cpp$ ./build/bin/test-backend-ops -o FLASH_ATTN_EXT
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: no, Wave Size: 64
Testing 2 devices

Backend 1/2: ROCm0
  Device description: AMD Instinct MI100
  Device memory: 32752 MB (32736 MB free)

  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.129117277 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): [FLASH_ATTN_EXT] NMSE = 1.136229698 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.104486589 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): [FLASH_ATTN_EXT] NMSE = 1.089722395 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.035668693 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.031973529 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.156827757 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): [FLASH_ATTN_EXT] NMSE = 2.356175198 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.110820343 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): [FLASH_ATTN_EXT] NMSE = 1.395904017 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.136781798 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 2.869900324 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.014801274 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.018263909 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.134523564 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.113570082 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.190856731 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 3.337441815 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.115495022 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 2.006258292 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.150619848 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): [FLASH_ATTN_EXT] NMSE = 1.139639138 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=64,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): Unsupported KV type combination for head_size 256.
Only f16 is supported.
/home/bjj/hjc4869-llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh:670: fatal error
Could not attach to process.  If your uid matches the uid of the target
process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try
again as the root user.  For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.
Aborted (core dumped)

@bjj
Contributor

bjj commented Feb 25, 2025

Maybe this is obvious, but the failing cases above are just all of the cases where:

#if defined(GGML_HIP_ROCWMMA_FATTN)
        if (fp16_mma_available(cc) && dst->src[0]->ne[1] > 8) {
            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
            return;
        }
#endif // defined(GGML_HIP_ROCWMMA_FATTN)

@hjc4869
Contributor Author

hjc4869 commented Feb 25, 2025

The crash is caused by a missing check in ggml_backend_cuda_device_supports_op; we'll perhaps need something like the diff below (adding op->src[0]->ne[1] > 8 to the condition, matching the one in ggml_cuda_flash_attn_ext).

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index ebb2ccae..dc6ed6ab 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3216,6 +3216,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
+
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+            return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16 && op->src[0]->ne[1] > 8;
+#elif defined(GGML_USE_HIP)
+            return false;
+#endif // defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+
             return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
         }
Differences in test results on a W7900DS between the PR branch and master after the above fix:
diff --git a/home/david/Desktop/master.log b/home/david/Desktop/pr.log
index d4a3d4ee..47899baa 100644
--- a/home/david/Desktop/master.log
+++ b/home/david/Desktop/pr.log
@@ -196,16 +196,16 @@ Backend 1/2: ROCm0
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
@@ -220,11 +220,11 @@ Backend 1/2: ROCm0
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
@@ -244,16 +244,16 @@ Backend 1/2: ROCm0
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
@@ -268,11 +268,11 @@ Backend 1/2: ROCm0
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
@@ -284,11 +284,11 @@ Backend 1/2: ROCm0
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
@@ -300,11 +300,11 @@ Backend 1/2: ROCm0
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
@@ -316,11 +316,11 @@ Backend 1/2: ROCm0
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
@@ -332,11 +332,11 @@ Backend 1/2: ROCm0
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
@@ -348,11 +348,11 @@ Backend 1/2: ROCm0
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 
-  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0] 
+  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0] 
   FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0] 

I currently only have RDNA3 GPUs, so I haven't tested the patch on CDNA. Perhaps we'll wait for @IMbackK to decide whether to enable this code path on a per-device basis based on test results, or whether something in my changes caused the issues on CDNA.

@adelj88

adelj88 commented Feb 25, 2025

Slightly off-topic, but I'm wondering whether we'll need different code paths for RDNA3/4 and CDNA in the future, and whether we should keep using rocWMMA in those future optimisations (we might eventually need to roll out something similar to fattn-mma-f16.cuh with custom tiles for maximum performance, as I've observed some limitations with rocWMMA in reaching it).

In the scenario where this approach has issues on CDNA, I wouldn't mind this PR enabling the code path specifically for RDNA3/4, at least until optimised code paths are introduced.

Comment on lines 256 to 261
#if defined(GGML_HIP_ROCWMMA_FATTN)
        if (fp16_mma_available(cc) && dst->src[0]->ne[1] > 8) {
            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
            return;
        }
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
Copy link
Collaborator


Is Q->ne[1] > 8 correct? In my experience even for very small batch sizes like 2 the wmma kernel can be faster (on NVIDIA).

Copy link
Contributor Author


The check was indeed not needed; it was an early hack in my private branches to get rid of the cols_per_block = 8 cases. I've removed it.
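
With that condition dropped, the HIP-only dispatch presumably reduces to something along these lines (a sketch based on the snippet quoted above, not the final diff):

#if defined(GGML_HIP_ROCWMMA_FATTN)
    // Route FlashAttention to the rocWMMA kernel whenever FP16 MMA is available,
    // with no lower bound on the batch size.
    if (fp16_mma_available(cc)) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
        return;
    }
#endif // defined(GGML_HIP_ROCWMMA_FATTN)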

@IMbackK
Copy link
Collaborator

IMbackK commented Mar 2, 2025

I currently only have RDNA3 GPUs, so I haven't tested the patch on CDNA. Perhaps we'll wait for @IMbackK to decide whether to enable this code path on a per-device basis based on test results, or whether something in my changes caused the issues on CDNA.

This PR can certainly only be merged once it works on CDNA, as I of course need the hardware to test it in order to maintain it, so we cannot restrict it to RDNA3+ WMMA. Luckily, an older version of a similar patch set to this one did work fine on CDNA, so it should not be too hard to find where it's going off the rails now, but I need to find the time for a debugging session.

@IMbackK
Copy link
Collaborator

IMbackK commented Mar 2, 2025

Slightly off-topic, but I'm wondering whether we'll need different code paths for RDNA3/4 and CDNA in the future, and whether we should keep using rocWMMA in those future optimisations (we might eventually need to roll out something similar to fattn-mma-f16.cuh with custom tiles, as I've observed some limitations in reaching maximum performance with rocWMMA).

The performance of this PR should be fairly poor compared to the device roofline, as it spends a lot of time shuffling data around and not so much doing useful work; it's only intended as a stop-gap.

@IMbackK
Copy link
Collaborator

IMbackK commented Mar 2, 2025

I must preface this by saying: I have no idea what I'm doing. But I think the issue is that CDNA wave size is 64 (see https://rocm.docs.amd.com/projects/rocWMMA/en/latest/api-reference/api-reference-guide.html ). By "test driven development" I have tested this theory by building only the necessary fattn stuff with #define WARP_SIZE 64 (can't just change it globally because there are too many things that assume it is 32). That makes the test pass and generation is sane.

@bjj oh, I see you already figured it out. If you don't mind, you could swap the use of the define for a constexpr int warp_size = ggml_cuda_get_physical_warp_size(); in device code and something like

int device;
int warp_size;

CUDA_CHECK(cudaGetDevice(&device));
warp_size = ggml_cuda_info().devices[device].warp_size;

when in host code, and post a diff.
Otherwise I can do so when I find the time to check this out (hopefully next weekend).
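
For illustration only, here is a minimal self-contained HIP sketch of that pattern; physical_warp_size() below is a stand-in for the real ggml_cuda_get_physical_warp_size() from common.cuh, and the kernel is a toy, not the FlashAttention code:

#include <hip/hip_runtime.h>
#include <cstdio>

// Stand-in for ggml_cuda_get_physical_warp_size(): constexpr, so device code
// can treat the warp size as a compile-time constant (64 on CDNA/GCN, 32 elsewhere).
static constexpr __host__ __device__ int physical_warp_size() {
#if defined(__AMDGCN_WAVEFRONT_SIZE__)
    return __AMDGCN_WAVEFRONT_SIZE__;
#else
    return 32;
#endif
}

// Toy kernel: warp_size is a compile-time constant here, so it could also size
// shared-memory arrays or drive #pragma unroll loops, like WARP_SIZE used to.
__global__ void scale_by_warp_size(const float * x, float * dst, const int n) {
    constexpr int warp_size = physical_warp_size();
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = x[i] / warp_size;
    }
}

int main() {
    // Host code: the warp size is only known at runtime, queried per device
    // (ggml does this via ggml_cuda_info().devices[device].warp_size).
    int device = 0;
    int warp_size = 32;
    hipGetDevice(&device);
    hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, device);
    printf("device %d has warp size %d\n", device, warp_size);

    const int n = 1024;
    float * x   = nullptr;
    float * dst = nullptr;
    hipMalloc(&x,   n*sizeof(float)); // contents are irrelevant for this sketch
    hipMalloc(&dst, n*sizeof(float));
    // launch with a block size derived from the runtime warp size
    scale_by_warp_size<<<n/(4*warp_size), 4*warp_size>>>(x, dst, n);
    hipDeviceSynchronize();
    hipFree(x);
    hipFree(dst);
    return 0;
}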

@IMbackK IMbackK self-assigned this Mar 2, 2025
@bjj
Copy link
Contributor

bjj commented Mar 2, 2025

@IMbackK I'm happy to try to clean up this diff, but I don't know anything about llama.cpp internals (or CUDA, or...), so I'd need some coaching. Here's how I approached it before:

  1. The flash attention kernel in fattn-wmma-f16.cuh as well as its caller must know WARP_SIZE=64... but
  2. You can't globally change WARP_SIZE, because there are a lot of other hardcoded constants that interact with WARP_SIZE and cause compile errors, like ggml_cuda_mma::tile dividing by WARP_SIZE all over the place, and then hardcoded tile sizes that trip asserts or make 1-length arrays that are assumed to be >1... but
  3. You can't just change WARP_SIZE in fattn.cu and fattn-wmma-f16.cuh because they rely on helpers in common.cuh. So I ended up making it possible to include common.cuh with a redefined WARP_SIZE, and to my surprise, that worked.
  4. The heuristics for sizing the kernel (choosing parallel_blocks, cols_per_block, etc) choose sizes that don't work (such as all cases where parallel_blocks=4) so I just modified them until the tests pass. But the test coverage may not be enough (see above re: not hitting prec != GGML_PREC_DEFAULT), and I never understood the actual limit that is being hit.

If everything that cares about WARP_SIZE can tolerate it not being a compile-time constant, I can probably make that work. I would need help fixing the heuristics for selecting kernel sizes. I see things like warp_reduce_max, though, which are templated on WARP_SIZE and fully unrolled, divide by WARP_SIZE, and other loops unrolled with WARP_SIZE in their iterations, and I wonder whether this actually needs to build two different WARP_SIZE variants, or at least be templated on WARP_SIZE with _case helpers.

There's one other case which may be worth investigating: The M/N/K size ends up being 16/16/16 all the time, but CDNA could do 32/32/8+ (besides 16/16/16, the current code supports a case of 32/8/16). I'm not sure if that's valuable.

My hack to use fattn on CDNA, on top of this pull request
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index f90e7c6d..72000591 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -37,7 +37,9 @@
 #define STRINGIZE_IMPL(...) #__VA_ARGS__
 #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
 
+#ifndef WARP_SIZE
 #define WARP_SIZE 32
+#endif
 #define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 #define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
 
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index 68cfc6a3..4b625532 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -2,6 +2,8 @@
 // It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.
 // Long-term the WMMA code should be replaced with a dedicated Volta implementation.
 
+#define WARP_SIZE 64
+
 #include "common.cuh"
 #include "fattn-common.cuh"
 #include "fattn-wmma-f16.cuh"
@@ -64,7 +66,6 @@ static __global__ void flash_attn_ext_f16(
         NO_DEVICE_CODE;
         return;
     }
-
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
@@ -474,6 +475,10 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (4*blocks_num_pb1 < 2*nsm) {
+        //printf("want pb=4\n");
+    }
+
+    if (false && 4*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 4;
         fattn_kernel_t fattn_kernel;
         if (logit_softcap == 0.0f) {
@@ -486,6 +491,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
         launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        //printf("pb=4 blocks_num_pb1 %d nsm %d D %d cols_per_block %d nwarps %d stride %d\n", blocks_num_pb1, nsm, D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m));
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
@@ -500,6 +506,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
+        //printf("pb=2 blocks_num_pb1 %d nsm %d D %d cols_per_block %d nwarps %d stride %d\n", blocks_num_pb1, nsm, D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m));
         launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
         return;
     }
@@ -514,6 +521,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
     }
+    //printf("pb=1 blocks_num_pb1 %d nsm %d D %d cols_per_block %d nwarps %d stride %d\n", blocks_num_pb1, nsm, D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m));
     launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
 }
 
@@ -523,8 +531,9 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
 
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
+    //printf("prec %d Q->ne[1] %ld Q->ne[0] %ld\n", prec, Q->ne[1], Q->ne[0] );
     if (prec != GGML_PREC_DEFAULT) {
-        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+        if (true || Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
             switch (Q->ne[0]) {
                 case 64:
@@ -602,7 +611,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
     }
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
-    if (Q->ne[1] <= 32) {
+    if (true|| Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;
         switch (Q->ne[0]) {
             case 64:
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index decceeb2..777980c4 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -1,3 +1,4 @@
+#define WARP_SIZE 64
 #include "common.cuh"
 #include "fattn-common.cuh"
 #include "fattn-mma-f16.cuh"
diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt
index e3762649..1e97b6c2 100644
--- a/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ggml/src/ggml-hip/CMakeLists.txt
@@ -40,9 +40,9 @@ find_package(hip     REQUIRED)
 find_package(hipblas REQUIRED)
 find_package(rocblas REQUIRED)
 if (GGML_HIP_ROCWMMA_FATTN)
-    CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
-    if (NOT ${FOUND_ROCWMMA})
-        message(FATAL_ERROR "rocwmma has not been found")
+    CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA "-I${ROCM_PATH}/include")
+    if (NOT FOUND_ROCWMMA)
+        #message(FATAL_ERROR "rocwmma has not been found")
     endif()
 endif()
 
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index e1f7e675..4830976a 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -54,7 +54,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
             std::uniform_real_distribution<float> distribution(min, max);
             auto & gen = generators[ith];
             for (size_t i = start; i < end; i++) {
-                data[i] = distribution(gen);
+                data[i] = distribution(gen); // zero
             }
         };
 
@@ -163,6 +163,7 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
                     } else {
                         GGML_ABORT("fatal error");
                     }
+                    //printf( "%ld %ld %ld %ld %9.6f\n", i0, i1, i2, i3, tv.back() );
                 }
             }
         }

@IMbackK
Copy link
Collaborator

IMbackK commented Mar 2, 2025

  1. The flash attention kernel in fattn-wmma-f16.cuh as well as its caller must know WARP_SIZE=64... but

Inside the kernel you can add constexpr int warp_size = ggml_cuda_get_physical_warp_size(); and then use it instead of WARP_SIZE in all cases; since ggml_cuda_get_physical_warp_size is constexpr, it remains a compile-time constant.

  2. You can't globally change WARP_SIZE, because there are a lot of other hardcoded constants that interact with WARP_SIZE and cause compile errors, like ggml_cuda_mma::tile dividing by WARP_SIZE all over the place, and then hardcoded tile sizes that trip asserts or make 1-length arrays that are assumed to be >1... but

I am working on changing this, but yes, we are not ready to remove the hardcoded WARP_SIZE yet.

  3. You can't just change WARP_SIZE in fattn.cu and fattn-wmma-f16.cuh because they rely on helpers in common.cuh.

The helpers needed from common.cuh should all support a selectable warp size via a template parameter that defaults to WARP_SIZE; you can simply pass in the warp size you get from ggml_cuda_get_physical_warp_size(), like warp_reduce_sum<warp_size>(). If a helper function is missing this, I can help you add it.
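
As a rough sketch (not the exact common.cuh code), such a width-templated helper looks something like this; the default keeps existing call sites compiling unchanged, while wave64 code passes the width explicitly:

// Warp-wide sum whose shuffle width is a template parameter defaulting to WARP_SIZE.
template <int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, width);
    }
    return x;
}

Inside the kernel it would then be called as warp_reduce_sum<warp_size>(x), with warp_size obtained from ggml_cuda_get_physical_warp_size() as described above.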

  4. The heuristics for sizing the kernel (choosing parallel_blocks, cols_per_block, etc) choose sizes that don't work (such as all cases where parallel_blocks=4) so I just modified them until the tests pass. But the test coverage may not be enough (see above re: not hitting prec != GGML_PREC_DEFAULT), and I never understood the actual limit that is being hit.

I can't tell you what the problem is here without taking the time to look at the code, but it is likely that this is simply every other size creating a problem size not divisible into 64-wide warps.

If everything that cares about WARP_SIZE can tolerate it not being a compile-time constant, I can probably make that work. I would need help fixing the heuristics for selecting kernel sizes. I see things like warp_reduce_max, though, which are templated on WARP_SIZE and fully unrolled, divide by WARP_SIZE, and other loops unrolled with WARP_SIZE in their iterations, and I wonder whether this actually needs to build two different WARP_SIZE variants, or at least be templated on WARP_SIZE with _case helpers.

As mentioned, in device code (i.e. the kernel: __global__ and __device__ functions) the warp size you get from ggml_cuda_get_physical_warp_size is a compile-time constant and can be used as such. In host code (all other functions) we don't know the warp size at compile time; here you need to use ggml_cuda_info().devices[device].warp_size, which is only known at runtime. In cases where it is important that a compile-time constant is available in host code, you will have to translate it into one by doing something like this:

switch (ggml_cuda_info().devices[device].warp_size) {
    case 64: {
        constexpr int warp_size = 64;
        // ... instantiate / launch the warp_size == 64 variant ...
        break;
    }
    case 32:
    default: {
        constexpr int warp_size = 32;
        // ... instantiate / launch the warp_size == 32 variant ...
        break;
    }
}

@hjc4869
Copy link
Contributor Author

hjc4869 commented Mar 3, 2025

I currently only have RDNA3 GPUs, so I haven't tested the patch on CDNA. Perhaps we'll wait for @IMbackK to decide whether to enable this code path on a per-device basis based on test results, or whether something in my changes caused the issues on CDNA.

This PR can certainly only be merged once it works on CDNA, as I of course need the hardware to test it in order to maintain it, so we cannot restrict it to RDNA3+ WMMA. Luckily, an older version of a similar patch set to this one did work fine on CDNA, so it should not be too hard to find where it's going off the rails now, but I need to find the time for a debugging session.

Should we move the discussion & review process to the new PR that includes proper CDNA support, or merge the two changes separately?

@IMbackK
Copy link
Collaborator

IMbackK commented Mar 3, 2025

@hjc4869 I would prefer to consolidate the discussion here in this PR by you pulling the changes by @bjj into the branch backing this PR, as this will keep master bisectable on CDNA once it is squashed and merged, which merging separately would not.

@bjj please close the other PR once @hjc4869 has done so.

@hjc4869
Copy link
Contributor Author

hjc4869 commented Mar 3, 2025

Pulled in the commits from that PR and checked that the code passes the tests on RDNA3. Please review and propose changes if needed.

Copy link
Collaborator

@IMbackK IMbackK left a comment


Mostly looks good from a functional perspective. I will profile it and take a crack at CDNA optimization once it's merged. RDNA3 performance should be as good as you can expect from this path.

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
Copy link
Collaborator


An occupancy of 1 here is especially bad on CDNA, as its caches are quite slow, but I will do a performance examination of this whole path at a later date.

@hjc4869 hjc4869 requested a review from IMbackK March 3, 2025 14:25
Copy link
Collaborator

@IMbackK IMbackK left a comment


Looks good now! I also confirm correctness and improved performance in pp and bs > 1 on CDNA.

@IMbackK
Copy link
Collaborator

IMbackK commented Mar 3, 2025

Maybe someone with a Volta GPU can confirm that it is still fine; this is the NVIDIA GPU that shares this code path. @ggerganov I understand you have one?

@ggerganov
Copy link
Member

@IMbackK This branch worked successfully with V100:

cmake .. -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=native
./bin/test-backend-ops 

Btw, as a collaborator, you can push a temp branch in this repository and trigger the ggml-ci on it. The ggml-4 runner is a V100 machine:

(screenshot of the ggml-ci runners omitted)

@IMbackK
Copy link
Collaborator

IMbackK commented Mar 3, 2025

@IMbackK This branch worked successfully with V100:

cmake .. -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=native
./bin/test-backend-ops 

Btw, as a collaborator, you can push a temp branch in this repository and trigger the ggml-ci on it. The ggml-4 runner is a V100 machine:

I see, thank you!

@IMbackK IMbackK merged commit becade5 into ggml-org:master Mar 3, 2025
47 checks passed
@hjc4869 hjc4869 deleted the pr branch March 4, 2025 02:05
@Headcrabed
Copy link

Btw, why do we disable GGML_HIP_ROCWMMA_FATTN by default?

@IMbackK
Copy link
Collaborator

IMbackK commented Mar 4, 2025

@JohannesGaessler did not want the default dependency on rocWMMA. We can maybe enable it by default when rocWMMA is detected at build time (as rocWMMA is header-only), but let's leave it be for now.
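
Not something this PR does, but detection-based defaulting in ggml/src/ggml-hip/CMakeLists.txt could look roughly like this (the option/variable handling is illustrative only):

# Hypothetical sketch: default GGML_HIP_ROCWMMA_FATTN to ON when the header-only
# rocWMMA package is found, while still letting the user override it explicitly.
if (NOT DEFINED GGML_HIP_ROCWMMA_FATTN)
    CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
    if (FOUND_ROCWMMA)
        set(GGML_HIP_ROCWMMA_FATTN ON)
    endif()
endif()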

@Headcrabed
Copy link

@JohannesGaessler did not want the default dependency on rocWMMA. We can maybe enable it by default when rocWMMA is detected at build time (as rocWMMA is header-only), but let's leave it be for now.

So we'll have to wait for an improved implementation (maybe not based on rocWMMA?) for it to be enabled by default?

@hjc4869
Copy link
Contributor Author

hjc4869 commented Mar 4, 2025

I think we could enable it in the GitHub Actions workflow without the concern of introducing a default dependency, and users could still easily access the optimized builds by downloading the prebuilt binaries.

Enable rocWMMA in CI and document how to enable it in the build: #12179
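
For anyone reading along, turning it on in a local ROCm build currently looks something like this (the GPU target and any HIP toolchain variables depend on your setup; gfx1100 is just an example):

cmake -S . -B build -DGGML_HIP=ON -DGGML_HIP_ROCWMMA_FATTN=ON -DAMDGPU_TARGETS=gfx1100 -DCMAKE_BUILD_TYPE=Release
cmake --build build --config Release -j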

Copy link
Collaborator

@JohannesGaessler JohannesGaessler left a comment


Sorry for the late review; I was busy with my PhD. This looks good to me.

mglambda pushed a commit to mglambda/llama.cpp that referenced this pull request Mar 8, 2025
…rg#12032)

Adds GGML_HIP_ROCWMMA_FATTN and rocwmma header check
Adds rocWMMA support to fattn-wmma-f16

---

Signed-off-by: Carl Klemm <[email protected]>
Co-authored-by: Johannes Gäßler <[email protected]>
Co-authored-by: Ben Jackson <[email protected]>
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Mar 19, 2025
…rg#12032)

Adds GGML_HIP_ROCWMMA_FATTN and rocwmma header check
Adds rocWMMA support to fattn-wmma-f16

---

Signed-off-by: Carl Klemm <[email protected]>
Co-authored-by: Johannes Gäßler <[email protected]>
Co-authored-by: Ben Jackson <[email protected]>
Labels
ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants