For batch size 1 (BS=1), the generation-phase flash_attn_vec_ext_f32 kernel is launched with a constant parallel_blocks value of 4 (see the kernel launch code).
However, parallel_blocks = 4 results in poor occupancy on the GPU.
Consider the following models. In each case, the current occupancy is far below what is achievable if the parallel_blocks value is increased.
| Model | num_heads | head_dim | Occupancy with PB=4 on RTX 4090 | Achievable occupancy with optimal PB value on RTX 4090 |
| --- | --- | --- | --- | --- |
| Llama 3B | 24 | 128 | 0.06 | 0.25 |
| Llama 8B | 32 | 128 | 0.08 | 0.25 |
| Qwen 1.5B | 12 | 128 | 0.03 | 0.25 |
| Qwen 7B | 28 | 128 | 0.07 | 0.25 |
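
As a rough sanity check, the PB=4 occupancy figures above can be reproduced with a back-of-the-envelope model. This is only a sketch and rests on assumptions that are not stated in this description: an RTX 4090 with 128 SMs and at most 48 resident warps per SM, a block size of 128 threads (4 warps) for head_dim 128, and one block launched per (head, parallel block) pair. The helper names are hypothetical, and this is not necessarily how the proposed change selects parallel_blocks.

```python
import math

# Assumed hardware/launch parameters (not taken from the PR text):
NUM_SMS = 128            # SM count assumed for RTX 4090
MAX_WARPS_PER_SM = 48    # 1536 resident threads per SM / 32 threads per warp
WARPS_PER_BLOCK = 4      # assumes 128 threads per block for head_dim = 128


def estimated_occupancy(num_heads: int, parallel_blocks: int) -> float:
    """Fraction of the GPU's warp slots filled, assuming one block per
    (head, parallel block) pair and BS=1."""
    active_warps = num_heads * parallel_blocks * WARPS_PER_BLOCK
    return active_warps / (NUM_SMS * MAX_WARPS_PER_SM)


def min_parallel_blocks(num_heads: int, target_occupancy: float) -> int:
    """Smallest parallel_blocks value that reaches target_occupancy
    under the same simplified model."""
    warps_needed = target_occupancy * NUM_SMS * MAX_WARPS_PER_SM
    return math.ceil(warps_needed / (WARPS_PER_BLOCK * num_heads))


if __name__ == "__main__":
    models = {"Llama 3B": 24, "Llama 8B": 32, "Qwen 1.5B": 12, "Qwen 7B": 28}
    for name, heads in models.items():
        occ = estimated_occupancy(heads, 4)
        pb = min_parallel_blocks(heads, 0.25)
        print(f"{name}: occupancy at PB=4 = {occ:.2f}, PB for 0.25 = {pb}")
```

Under these assumptions, the grid at PB=4 contains fewer blocks than the GPU has SMs for three of the four models (for example, 24 heads x 4 = 96 blocks against 128 SMs), which is consistent with the low occupancy reported in the table.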
I have a change that addresses this issue; it improves generation-phase performance by up to 14%.