[V1] V1 FlashInfer Attention #16684
Conversation
Signed-off-by: mgoin <[email protected]>
Signed-off-by: mgoin <[email protected]>
Will need to rebase on #16673.
Signed-off-by: mgoin <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: mgoin <[email protected]>
Signed-off-by: mgoin <[email protected]>
LGTM, thanks for doing this!
Can we use FlashInfer's fp8 KV cache with this PR? vLLM now only does …
@JaheimLee I tried enabling it here, but FlashInfer fails to compile: #17005
Signed-off-by: mgoin <[email protected]> Co-authored-by: Aurick Qiao <[email protected]> Signed-off-by: Frieda (Jingying) Huang <[email protected]>
Signed-off-by: mgoin <[email protected]> Co-authored-by: Aurick Qiao <[email protected]>
Signed-off-by: mgoin <[email protected]> Co-authored-by: Aurick Qiao <[email protected]>
Signed-off-by: mgoin <[email protected]> Co-authored-by: Aurick Qiao <[email protected]> Signed-off-by: Agata Dobrzyniewicz <[email protected]>
Signed-off-by: mgoin <[email protected]> Co-authored-by: Aurick Qiao <[email protected]> Signed-off-by: Mu Huai <[email protected]>
Carrying on @aurickq's work from #14061. Thanks to @LucasWilkinson for helping debug qo_indptr issues.
There are some performance issues in the original PR due to using `BatchPrefillWithPagedKVCacheWrapper` for all prefill and decode tokens. This PR separates prefill and decode tokens in V1 using the `reorder_batch()` functionality added for MLA, where the requests in the `input_batch` are reshuffled so that all decode tokens are at the front and all prefill tokens are at the back. This makes it easy to split the input/output of the attention implementation into contiguous chunks for decode and prefill. With this new implementation, FlashInfer 0.2.1.post2 is close to the performance of FA3.
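To illustrate the idea, here is a minimal sketch (not the actual vLLM code; the function and parameter names are hypothetical). Once `reorder_batch()` has placed all decode requests at the front of the batch, each decode request contributes exactly one query token, so the flattened token stream can be split into contiguous decode and prefill chunks with simple slicing:

```python
import torch

def split_decode_prefill(query: torch.Tensor, num_decodes: int):
    """Split a flattened [total_tokens, ...] query tensor into contiguous
    decode and prefill chunks, assuming the batch was already reordered so
    that the first `num_decodes` requests are decodes (one token each)."""
    # Decode requests contribute exactly one token each, so the decode
    # chunk is simply the first `num_decodes` rows.
    decode_query = query[:num_decodes]
    # Everything after that belongs to prefill requests.
    prefill_query = query[num_decodes:]
    return decode_query, prefill_query

# Example: 2 decode requests (1 token each) followed by one prefill of 5 tokens.
q = torch.randn(7, 8)  # [total_tokens, head_dim]
dq, pq = split_decode_prefill(q, num_decodes=2)
assert dq.shape[0] == 2 and pq.shape[0] == 5
```

Because the chunks are contiguous, the decode tokens can be dispatched to the decode-optimized path and the prefill tokens to the prefill path without any gather/scatter.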
Evaluations
Evaluations on GSM8k:
Benchmarks
Benchmarks run on H100:
Llama 8B at 1024/128 input/output tokens, showing the improvement over the mixed implementation and comparing against FA3:
Llama 8B at 1000/1000 input/output tokens against FA3:
QwQ 32B FP8-dynamic TP=2 at 1000/1000 input/output tokens against FA3: