
Conversation

@mgoin mgoin commented Apr 15, 2025

Carrying on @aurickq's work from #14061. Thanks to @LucasWilkinson for helping debug the qo_indptr issues.

The original PR had some performance issues because it used BatchPrefillWithPagedKVCacheWrapper for all prefill and decode tokens. This PR separates prefill and decode tokens in V1 using the reorder_batch() functionality added for MLA: the requests in the input_batch are reshuffled so that all decode requests sit at the front and all prefill requests at the back. This makes it easy to split the input/output of the attention implementation into contiguous decode and prefill chunks, as sketched below.
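
As a rough illustration (a minimal sketch with assumed names and shapes, not the actual vLLM implementation), once the batch is reordered the flattened token dimension can be split into two contiguous views, one per path:

import torch

def split_decode_prefill(query: torch.Tensor, num_decode_tokens: int):
    # query: [num_tokens, num_heads, head_dim]; after reorder_batch() the
    # first num_decode_tokens rows are the decode tokens (one per request)
    # and the remaining rows are the prefill tokens.
    decode_query = query[:num_decode_tokens]
    prefill_query = query[num_decode_tokens:]
    # Each chunk is handed to its own FlashInfer wrapper (decode vs. prefill),
    # and the two outputs are concatenated back in the same order.
    return decode_query, prefill_query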

With this new implementation, FlashInfer 0.2.1.post2 comes close to matching the performance of FA3.

Evaluations

Evaluations on GSM8k:

export VLLM_ATTENTION_BACKEND=FLASHINFER
lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.2-1B-Instruct --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
vllm (pretrained=meta-llama/Llama-3.2-1B-Instruct,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3351|±  | 0.013|
|     |       |strict-match    |     5|exact_match|↑  |0.3351|±  | 0.013|

lm_eval --model vllm --model_args pretrained=Qwen/Qwen2.5-7B-Instruct --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
vllm (pretrained=Qwen/Qwen2.5-7B-Instruct,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8264|±  |0.0104|
|     |       |strict-match    |     5|exact_match|↑  |0.7885|±  |0.0112|

lm_eval --model vllm --model_args pretrained=RedHatAI/QwQ-32B-FP8-dynamic,tensor_parallel_size=2 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
vllm (pretrained=RedHatAI/QwQ-32B-FP8-dynamic,tensor_parallel_size=2,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4321|±  |0.0136|
|     |       |strict-match    |     5|exact_match|↑  |0.7369|±  |0.0121|

Benchmarks

Benchmarks run on H100:

Llama 8B at 1024/128 input/output tokens, showing the improvement over the old mixed (combined prefill+decode) implementation and comparing against FA3:

python benchmarks/benchmark_throughput.py --model meta-llama/Llama-3.1-8B-Instruct --num-prompts 1000 --input-len 1024 --output-len 128

# V1 FA3
Throughput: 25.63 requests/s, 30776.11 total tokens/s, 3280.51 output tokens/s

# V1 Original Flashinfer (Old PR, combined prefill+decode)
Throughput: 15.51 requests/s, 18616.93 total tokens/s, 1985.48 output tokens/s

# V1 Flashinfer (This PR)
Throughput: 25.09 requests/s, 30112.70 total tokens/s, 3212.02 output tokens/s

Llama 8B at 1000/1000 input/output tokens against FA3:

python benchmarks/benchmark_throughput.py --model meta-llama/Llama-3.1-8B-Instruct --num-prompts 1000 --input-len 1000 --output-len 1000

export VLLM_ATTENTION_BACKEND=FLASHINFER
Throughput: 5.93 requests/s, 12136.45 total tokens/s, 5931.55 output tokens/s

export VLLM_ATTENTION_BACKEND=FLASH_ATTN 
Throughput: 6.17 requests/s, 12632.79 total tokens/s, 6165.96 output tokens/s

QwQ 32B FP8-dynamic TP=2 at 1000/1000 input/output tokens against FA3:

python benchmarks/benchmark_throughput.py --model RedHatAI/QwQ-32B-FP8-dynamic --tensor-parallel-size=2 --num-prompts 1000 --input-len 1000 --output-len 1000

export VLLM_ATTENTION_BACKEND=FLASHINFER
Throughput: 4.25 requests/s, 8748.90 total tokens/s, 4247.52 output tokens/s

export VLLM_ATTENTION_BACKEND=FLASH_ATTN
Throughput: 4.25 requests/s, 8741.93 total tokens/s, 4248.12 output tokens/s
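
Each run above just switches the VLLM_ATTENTION_BACKEND environment variable before invoking benchmark_throughput.py. The same selection can be made from Python; a minimal sketch (the prompt and sampling settings are illustrative, not the exact benchmark configuration):

import os

# Equivalent to `export VLLM_ATTENTION_BACKEND=FLASHINFER`; use "FLASH_ATTN"
# instead to compare against FA3. Set this before the engine is created.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
params = SamplingParams(max_tokens=128, ignore_eos=True)  # fixed-length output, as in the benchmark
outputs = llm.generate(["Summarize the benefits of paged attention."], params)
print(outputs[0].outputs[0].text)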

Signed-off-by: mgoin <[email protected]>
@mgoin mgoin commented Apr 16, 2025

Will need to rebase on #16673

Signed-off-by: mgoin <[email protected]>
@mgoin mgoin changed the title from "V1 FlashInfer Attention" to "[V1] V1 FlashInfer Attention" Apr 18, 2025
@mergify mergify bot commented Apr 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mgoin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 18, 2025
@mergify mergify bot removed the needs-rebase label Apr 18, 2025
@mgoin mgoin added the ready (ONLY add when PR is ready to merge/full CI is needed) label Apr 18, 2025
@mgoin mgoin requested a review from LucasWilkinson April 21, 2025 21:51
@LucasWilkinson LucasWilkinson left a comment

LGTM, thanks for doing this!

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) April 21, 2025 22:08
@LucasWilkinson LucasWilkinson merged commit 986537f into vllm-project:main Apr 22, 2025
63 checks passed
@JaheimLee JaheimLee commented Apr 22, 2025

Can we use FlashInfer's fp8 KV cache with this PR? vLLM currently only does:

# fp8 attention is only marked supported when the FlashAttention backend will be used
will_use_fa = (
    current_platform.is_cuda()
    and not envs.is_set("VLLM_ATTENTION_BACKEND")
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
supported = False
if fp8_attention and will_use_fa:
    from vllm.vllm_flash_attn.fa_utils import (
        flash_attn_supports_fp8)
    supported = flash_attn_supports_fp8()
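
For context, here is a hypothetical sketch of what the question is asking for, using vLLM's existing kv_cache_dtype option together with the FlashInfer backend (the model name is illustrative; whether this combination works is exactly the open question, see the reply below):

import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM

# kv_cache_dtype="fp8" requests an fp8 KV cache, but as the check above shows,
# the fp8-attention fast path is currently only taken for the FlashAttention backend.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_dtype="fp8")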

@mgoin mgoin deleted the flashinfer-v1 branch April 22, 2025 20:26
@mgoin mgoin commented Apr 22, 2025

@JaheimLee I tried enabling it in #17005, but flashinfer fails to compile.

frieda-huang pushed a commit to frieda-huang/vllm that referenced this pull request Apr 23, 2025
Signed-off-by: mgoin <[email protected]>
Co-authored-by: Aurick Qiao <[email protected]>
Signed-off-by: Frieda (Jingying) Huang <[email protected]>
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
Signed-off-by: mgoin <[email protected]>
Co-authored-by: Aurick Qiao <[email protected]>
Signed-off-by: Agata Dobrzyniewicz <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: mgoin <[email protected]>
Co-authored-by: Aurick Qiao <[email protected]>
Signed-off-by: Mu Huai <[email protected]>