
Conversation

@KuntaiDu (Collaborator) commented Jul 6, 2024

This is a follow-up PR for #5557 .

Goal: implement disaggregated prefilling by launching 2 vLLM instances (one for prefilling, one for decoding) and forwarding the KV cache from the prefilling instance to the decoding instance. (A minimal client-side sketch of this flow follows the roadmap below.)

A rough roadmap:

  • Benchmark an idealized version of disaggregated prefilling (idealized in the sense that the KV cache transfer completes instantaneously)
  • Implement API calls in vLLM to import/export the KV cache
  • Implement an agent that can transfer the KV cache between the prefilling and decoding instances
  • Implement an end-to-end prototype
  • Benchmark and improve the performance
  • Beta test (ongoing)
  • Support different tp/pp between prefill and decode instance
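To make the intended flow concrete, here is a minimal client-side sketch, assuming two OpenAI-compatible vLLM servers; the ports, model name, and the "ask the prefill instance for one token, then ask the decode instance for the full completion" pattern are illustrative placeholders, not this PR's fixed API.

```python
import requests

PREFILL_URL = "http://localhost:8100/v1/completions"  # prefill instance (assumed port)
DECODE_URL = "http://localhost:8200/v1/completions"   # decode instance (assumed port)

def disagg_generate(prompt: str, max_tokens: int = 128) -> str:
    # 1) The prefill instance runs the prompt (one token is enough to force the
    #    full prefill) and produces the KV cache to be forwarded.
    requests.post(PREFILL_URL, json={"model": "placeholder-model",
                                     "prompt": prompt, "max_tokens": 1})
    # 2) The decode instance reuses the transferred KV cache and generates the rest.
    resp = requests.post(DECODE_URL, json={"model": "placeholder-model",
                                           "prompt": prompt, "max_tokens": max_tokens})
    return resp.json()["choices"][0]["text"]
```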

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well-documented so that future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies user-facing behavior of vLLM. This helps vLLM users understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag the PR with rfc-required and may not review it.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient, and to make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide a status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@KuntaiDu marked this pull request as a draft on July 6, 2024 07:44
@KuntaiDu (Collaborator, Author) commented Jul 7, 2024

An example where disaggregated prefill can do much better than chunked prefill (benchmark figure attached):

  • model: llama70B fp8
  • device: 8xH100
  • workload: QPS 4, input tokens 2048, output tokens 11
  • 3 approaches:
    • chunked prefill — tp8
    • chunked prefill — 2x tp4
    • disaggregated prefill — 1x tp4 prefill, 1x tp4 decode

@KuntaiDu (Collaborator, Author) commented Jul 9, 2024

Summary of the measurement insights:

  • For long contexts (here, #input tokens >= 1k with #output tokens = 128), we should use disaggregated prefilling instead of chunked prefill.
  • The maximum overhead of disaggregated prefilling (caused by the KV cache transfer) is 40 ms; an ideal implementation should have less than 10 ms of overhead.

@KuntaiDu (Collaborator, Author) commented:

A back-to-back comparison between chunked prefill and disaggregated prefill:

  • Input length: 2048
  • Output length: 150
  • Dataset: sonnet
  • Num of prompts: 400
  • QPS: 2,4,6,8
  • Methods:
    • chunked prefill: 2 vLLM instances, tp4, with chunked prefill enabled; the 2 instances share the workload in a round-robin manner
    • disagg prefill: 2 vLLM instances, tp4, one for prefill and one for decode
      • My current implementation can return the first token before the implementation overheads (KV transfer, and waiting until the decode instance is ready to receive the generated KV cache) occur, by fetching the first token from the prefill instance. For benchmarking's sake, however, I count these overheads toward TTFT by fetching the first token from the decode instance (a measurement sketch follows the figures below).
  • Results:
    • Lower median TTFT and median ITL when QPS <= 6 (at QPS = 8 the decode instance is backlogged)
    • Worse p99 ITL: sometimes the KV cache transfer fails (this happens rarely and I am not yet sure why), forcing the decode instance to redo the prefill itself, which worsens ITL

(benchmark figures attached)
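For reference, a minimal sketch of how TTFT/ITL can be measured in this style: stream from the instance that returns tokens to the user (here the decode instance) and timestamp each streamed chunk. The URL, port, and model name are placeholders, not values from this PR.

```python
import time
import requests

def measure_latency(url="http://localhost:8200/v1/completions",
                    model="placeholder-model", prompt="...", max_tokens=150):
    """Returns (TTFT, list of inter-token latencies) for one streamed request."""
    t0 = time.perf_counter()
    ttft, itls, prev = None, [], None
    payload = {"model": model, "prompt": prompt,
               "max_tokens": max_tokens, "stream": True}
    with requests.post(url, json=payload, stream=True) as r:
        for line in r.iter_lines():
            # OpenAI-compatible servers stream SSE lines: b"data: {...}" / b"data: [DONE]"
            if not line or not line.startswith(b"data: ") or line == b"data: [DONE]":
                continue
            now = time.perf_counter()
            if ttft is None:
                ttft = now - t0          # time to first token
            elif prev is not None:
                itls.append(now - prev)  # inter-token latency
            prev = now
    return ttft, itls
```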

@wjj19950828 commented:

> A back-to-back comparison between chunked prefill and disaggregated prefill: […]

@KuntaiDu Why do the TTFT numbers reach 5000-10000 ms at some concurrency levels? Are those requests pending?

Also, with a 70B model at tp4 on H100s, I would not expect the achievable QPS to be this low.

@KuntaiDu (Collaborator, Author) commented:

> @KuntaiDu Why do the TTFT numbers reach 5000-10000 ms at some concurrency levels? Are those requests pending? Also, with a 70B model at tp4 on H100s, I would not expect the achievable QPS to be this low.

Yes, the requests are pending, and that's why the TTFT is high. As for the QPS, let me double-check.

@LesLieZC0324 commented:

This is nice work! However, I ran into some problems in actual use.
In this PR, the chunked-prefill tp4 setup uses round_robin_proxy.sh to forward requests arriving at port 8000 to ports 8100 and 8200 in turn. However, when I used it, once the socat process is started and begins listening on port 8000, it keeps handling all subsequent connections itself, without calling the get_next_port function again to select another port.
Looking forward to your reply!
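For illustration, here is a minimal asyncio round-robin TCP proxy that picks a new backend per incoming connection. This is a hypothetical sketch of the behavior round_robin_proxy.sh intends, not the PR's script; the ports follow the ones mentioned above.

```python
import asyncio
import itertools

BACKEND_PORTS = itertools.cycle([8100, 8200])  # the two vLLM instances

async def pipe(reader, writer):
    try:
        while data := await reader.read(65536):
            writer.write(data)
            await writer.drain()
    finally:
        writer.close()

async def handle_client(client_reader, client_writer):
    # A new backend is chosen for every incoming connection, unlike a single
    # long-lived socat process that keeps forwarding to one fixed target.
    port = next(BACKEND_PORTS)
    backend_reader, backend_writer = await asyncio.open_connection("127.0.0.1", port)
    await asyncio.gather(pipe(client_reader, backend_writer),
                         pipe(backend_reader, client_writer))

async def main():
    server = await asyncio.start_server(handle_client, "0.0.0.0", 8000)
    async with server:
        await server.serve_forever()

if __name__ == "__main__":
    asyncio.run(main())
```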

@Yang-x-Zhao commented:

A friendly reminder: the chunked prefill feature takes a --max-num-batched-tokens parameter. In the original chunked prefill post, the author found that using 2048 instead of the default 512 gave better results on A100.

For disaggregated prefill, it might be interesting to set this parameter differently on the prefill and decode instances. Since the prefill stage is compute-bound and the decode stage is memory-bound, my intuition is to use a small max-num-batched-tokens for prefill and a large one for decode.

For instance, when I benchmarked Llama-3-8B with prefill on tp2 A100 and decode on tp2 A100, I set the prefill instance to --max-num-batched-tokens 4096 and the decode instance to --max-num-batched-tokens 32768. With these settings, I achieved a slightly better result.
(benchmark figure attached)
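For readers who want to reproduce the asymmetric setting, a sketch expressed as vLLM engine arguments; the benchmark above launched two separate API-server processes with the equivalent CLI flags, the model id is assumed, and the disaggregation-specific flags from this PR are omitted.

```python
from vllm.engine.arg_utils import EngineArgs

# Prefill is compute-bound: keep the per-step token budget small.
prefill_args = EngineArgs(model="meta-llama/Meta-Llama-3-8B",  # assumed model id
                          tensor_parallel_size=2,
                          max_num_batched_tokens=4096)

# Decode is memory-bound: allow a much larger per-step token budget.
decode_args = EngineArgs(model="meta-llama/Meta-Llama-3-8B",
                         tensor_parallel_size=2,
                         max_num_batched_tokens=32768)
```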

@wjj19950828 commented:

@KuntaiDu @MazarineGlacier Have you tested disaggregated prefill against the normal version (without chunked prefill)? I tested P2D2 vs tp4 and found no benefit. Is this normal?

@Yang-x-Zhao commented:

> I tested P2D2 vs tp4 and found no benefit. Is this normal?

In your case, disaggregated prefill will have worse TTFT. This is because vLLM prioritizes prefill by default, and tp4 has (a bit less than) twice the prefill compute capability of P2D2.

I am not certain about TPOT/ITL, though; that depends on the actual batched tokens.

So it is indeed possible to see no benefit.

@wjj19950828 commented:

> In your case, disaggregated prefill will have worse TTFT … it is indeed possible to see no benefit.

In our scenario, disaggregated prefill is much worse than tp4 on TTFT/TPOT/ITL. I don't know where the problem is, so I wonder in which scenarios disaggregated prefill is beneficial.

@Yang-x-Zhao commented:

> In our scenario, disaggregated prefill is much worse than tp4 on TTFT/TPOT/ITL. I don't know where the problem is, so I wonder in which scenarios disaggregated prefill is beneficial.

This paper might answer your question: https://arxiv.org/html/2401.11181v1. In it, disaggregated prefill fails to help when the workload is too heavy on both prefill and decode.

@ChuanhongLi commented:

> In our scenario, disaggregated prefill is much worse than tp4 on TTFT/TPOT/ITL. I don't know where the problem is, so I wonder in which scenarios disaggregated prefill is beneficial.

Me too, but I ran it on 4090s. Maybe the overhead is too high without NVLink.

@wjj19950828 commented:

> Me too, but I ran it on 4090s. Maybe the overhead is too high without NVLink.

Yes, I am looking into the KV cache transmission overhead here. I saw the author say it is about 30 ms, which is definitely unacceptable.

@LesLieZC0324 commented Aug 23, 2024

> Yes, I am looking into the KV cache transmission overhead here. I saw the author say it is about 30 ms, which is definitely unacceptable.

In our scenario, without NVLink there is a barrier in the KV cache transmission, which increases both TTFT and ITL (which includes TTFT).

@wjj19950828 commented Aug 25, 2024

@KuntaiDu In fact, I don't think the tolist operation is necessary for the hash calculation, as in:
input_tokens_tuple = tuple(model_input.input_tokens.tolist())
This makes the D2H copy time-consuming, especially when input_ids is very long. Do you have any other recommended methods for calculating the hash? Thanks~
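One possible direction, sketched here as a hypothetical alternative (not the hash used in this PR): compute an order-sensitive checksum on the device so that only a single scalar crosses the PCIe bus, instead of materializing the whole token list with .tolist(). A real implementation would want a stronger hash; the small modulus here is only for illustration.

```python
import torch

_MOD = 1_000_003  # small prime keeps every intermediate comfortably inside int64

def on_device_fingerprint(input_tokens: torch.Tensor) -> int:
    """Order-sensitive checksum of a token-id tensor, computed on its device."""
    t = input_tokens.reshape(-1).to(torch.int64) % _MOD
    # Position-dependent weights make the checksum sensitive to token order.
    w = (torch.arange(1, t.numel() + 1, device=t.device, dtype=torch.int64)
         * 2654435761) % _MOD
    return int((t * w).sum() % _MOD)  # only one scalar is copied device-to-host
```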

@KuntaiDu (Collaborator, Author) commented:

> Me too, but I ran it on 4090s. Maybe the overhead is too high without NVLink.

NVLink or InfiniBand is a must for disaggregated prefilling to beat chunked prefill; the time budget for the KV cache transfer is less than 50 ms.
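A back-of-envelope estimate of why the link matters, assuming a Llama-2-70B-like architecture (80 layers, 8 KV heads with GQA, head_dim 128) and an FP16 KV cache; the bandwidth numbers are rough per-link figures, not measurements from this PR.

```python
def kv_bytes_per_token(layers=80, kv_heads=8, head_dim=128, dtype_bytes=2):
    return 2 * layers * kv_heads * head_dim * dtype_bytes  # 2 accounts for K and V

prompt_bytes = 2048 * kv_bytes_per_token()  # ~640 MiB for a 2048-token prompt

for link, gb_per_s in [("NVLink (~300 GB/s)", 300),
                       ("PCIe 4.0 x16 (~32 GB/s)", 32),
                       ("100 GbE (~12.5 GB/s)", 12.5)]:
    ms = prompt_bytes / (gb_per_s * 1e9) * 1e3
    print(f"{link}: ~{ms:.1f} ms")  # ~2 ms, ~21 ms, ~54 ms respectively
```

Under these assumptions, the transfer alone consumes most (or all) of the 50 ms budget without NVLink-class bandwidth, which matches the observations above.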

@KuntaiDu (Collaborator, Author) commented:

> Yes, I am looking into the KV cache transmission overhead here. I saw the author say it is about 30 ms, which is definitely unacceptable.

There are several performance optimization opportunities I have not explored yet. In the current implementation, the first token is sampled twice and the model input is constructed twice. These overheads can be removed with engineering work and will be optimized once the implementation is stable.

I am now working on an upcoming vLLM performance post and will circle back to this right after that.

@KuntaiDu (Collaborator, Author) commented:

> … once the socat process is started and begins listening on port 8000, it keeps handling all subsequent connections without calling the get_next_port function again to select another port …

Oh, let me double-check and fix it.

@gursimar commented:

> … I did it on 4090. Maybe the overhead is too high without NVLink.

I'm interested in building upon this implementation for faster KV cache transfer.

  1. What is the current method of KV cache transfer if you are not using NVLink?
  2. Why can't we simply use torch.distributed.isend and torch.distributed.irecv to do the transfer over NCCL?
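Regarding point 2, torch.distributed point-to-point ops do work over NCCL. The sketch below assumes the prefill and decode workers already share a two-rank process group and that the receiver knows the tensor shape/dtype in advance (one of the practical complications); it is an illustrative sketch, not the transfer mechanism of this PR.

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl", rank=..., world_size=2) was called on
# both sides, e.g. rank 0 = prefill worker, rank 1 = decode worker.

def send_kv(kv: torch.Tensor, dst: int) -> None:
    req = dist.isend(kv, dst=dst)   # non-blocking NCCL send
    req.wait()                      # or overlap other work before waiting

def recv_kv(shape, dtype, src: int, device: str = "cuda") -> torch.Tensor:
    buf = torch.empty(shape, dtype=dtype, device=device)  # receiver must pre-know shape/dtype
    req = dist.irecv(buf, src=src)
    req.wait()
    return buf
```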

@Luis-xu commented Sep 6, 2024

@KuntaiDu I am very interested in this work. I found that currently only KV cache transmission with the flash-attn backend is supported. Are there any plans to support xformers and flashinfer?

@wenqf11 commented Sep 9, 2024

Thanks for your work. I have a question: can I start, say, 6 prefill instances and 2 decode instances on 8 GPUs, and how?

@WhatGhost commented Sep 10, 2024

@KuntaiDu Hi, I just want to know the relationship between your work and PR #2809. It seems you both want to implement something like "disaggregated prefilling" to separate the prefill and generation phases.

I wonder what the difference is.

Looking forward to your reply! Thanks!

@junna2016 commented:

I ran into a problem when concatenating the K and V tensors together for each layer before send and recv: Llama-2-7B's outputs become random and incorrect. Why does concatenating K and V lead to this?

@KuntaiDu (Collaborator, Author) commented:

Closing this PR now (I did a large-scale refactor, and it now lives in #8498).

@KuntaiDu (Collaborator, Author) commented:

> Thanks for your work. I have a question: can I start, say, 6 prefill instances and 2 decode instances on 8 GPUs, and how?

It is not implemented directly in the new PR, but yes, it is on the roadmap and will be implemented soon.

@KuntaiDu (Collaborator, Author) commented:

> @KuntaiDu Hi, I just want to know the relationship between your work and PR #2809 … I wonder what the difference is.

I am not sure about that thread. I skimmed their code; my implementation is more lightweight, and its overhead, though it will definitely be larger, is tolerable.

@KuntaiDu (Collaborator, Author) commented:

Deprecating this PR in favor of #8498 .

@KuntaiDu closed this on Sep 16, 2024