
Conversation

njhill
Member

@njhill njhill commented May 1, 2025

This is a follow-on from #15977.

A new --api-server-count arg to vllm serve can be used to run an arbitrary number of API server processes. When used in conjunction with --data-parallel-size, there is all-to-all ZMQ-based communication between the API servers and the data-parallel engines.
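For example, a single-node DP=2 deployment with two API server processes can be launched like this (the model path here is just a placeholder):

 vllm serve meta-llama/Llama-3.2-1B --data-parallel-size 2 --api-server-count=2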

It works with multi-node as described in #15977. All of the API servers run on the head node.

A separate "coordinator" process is now used for DP>1. This is responsible for ensuring that the engines run in tandem, and for publishing real-time request count information (and later likely other engine state info) back to the api server(s) for load balancing purposes.

[diagram]

More design discussion: https://docs.google.com/document/d/10jhCNxJYvsUhtMtiMAaW2MxU5LU8HVje2pGDnj49gH4/edit?tab=t.0

Performance now scales much better with DP size; note the TTFT numbers in particular below.

Benchmark with 2xA100, Llama-3.2-1B, ShareGPT with request rate 120 req/sec:
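The exact benchmark invocation isn't shown here; with vLLM's benchmarks/benchmark_serving.py it would presumably be something along these lines (the flags and dataset path are my assumption):

 python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Llama-3.2-1B \
     --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
     --num-prompts 10000 --request-rate 120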

DP=2 before

============ Serving Benchmark Result ============
Successful requests:                     10000     
Benchmark duration (s):                  130.74    
Total input tokens:                      2206428   
Total generated tokens:                  1994815   
Request throughput (req/s):              76.49     
Output token throughput (tok/s):         15258.46  
Total Token throughput (tok/s):          32135.56  
---------------Time to First Token----------------
Mean TTFT (ms):                          13176.40  
Median TTFT (ms):                        13953.03  
P99 TTFT (ms):                           26842.02  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          19.83     
Median TPOT (ms):                        22.28     
P99 TPOT (ms):                           36.10     
---------------Inter-token Latency----------------
Mean ITL (ms):                           24.19     
Median ITL (ms):                         21.98     
P99 ITL (ms):                            81.11     
==================================================

DP=2 with --api-server-count=2

============ Serving Benchmark Result ============
Successful requests:                     10000     
Benchmark duration (s):                  116.84    
Total input tokens:                      2206428   
Total generated tokens:                  1994815   
Request throughput (req/s):              85.59     
Output token throughput (tok/s):         17073.43  
Total Token throughput (tok/s):          35958.03  
---------------Time to First Token----------------
Mean TTFT (ms):                          67.54     
Median TTFT (ms):                        60.81     
P99 TTFT (ms):                           329.10    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          23.90     
Median TPOT (ms):                        24.13     
P99 TPOT (ms):                           36.89     
---------------Inter-token Latency----------------
Mean ITL (ms):                           23.62     
Median ITL (ms):                         22.14     
P99 ITL (ms):                            51.75     
==================================================

This is working functionally but there are still a number of tasks remaining:

  • Initial benchmark results for DP=1 and multiple API servers are disappointing; I am looking into why this currently hurts ITL and throughput slightly (though TTFT improves slightly).
  • (medium) Multiple API servers don't currently work properly with metrics publishing/logging. I have discussed this with @markmc but it needs a bit more work. @kouroshHakha is helping to look at this; I will add some more notes below.
  • (small) The multi-modal embeddings cache currently won't work with DP and/or multi-API, so it will need to be auto-disabled when dp > 1 and/or api-server-count > 1 (see the sketch after this list). Hopefully the scale-out will offset the performance cost of disabling it (discussed with @ywang96 and @DarkLight1337).
  • (small) When there are many API servers, a lot of the startup logs are duplicated. We probably want to suppress some of these.
  • (tbd) Need to look into implications for LoRA adapter loading.
  • (medium) Some more work on error handling and clean shutdown with the new process topologies.
  • (medium) Full test coverage of the various permutations.
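As a sketch of the auto-disable mentioned in the multi-modal cache item above (the function name is hypothetical, not vLLM's actual config code):

def mm_cache_supported(data_parallel_size: int, api_server_count: int) -> bool:
    # Per the task item above, the cache has to be auto-disabled whenever
    # dp > 1 and/or api-server-count > 1.
    return data_parallel_size <= 1 and api_server_count <= 1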

Follow-on work (not for this PR):

  • Rework how the multi-modal feature cache is implemented to make it compatible with the any-to-any process architecture.

njhill added 19 commits April 4, 2025 17:04
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
…-engines

Signed-off-by: Nick Hill <[email protected]>

# Conflicts:
#	vllm/v1/engine/core_client.py
#	vllm/v1/utils.py
Signed-off-by: Nick Hill <[email protected]>

# Conflicts:
#	vllm/config.py
#	vllm/engine/arg_utils.py
#	vllm/v1/engine/core.py
#	vllm/v1/engine/core_client.py
Signed-off-by: Nick Hill <[email protected]>

# Conflicts:
#	vllm/v1/engine/core.py
#	vllm/v1/engine/core_client.py
…-engines

Signed-off-by: Nick Hill <[email protected]>

# Conflicts:
#	vllm/config.py
#	vllm/v1/engine/core.py
Signed-off-by: Nick Hill <[email protected]>

# Conflicts:
#	vllm/v1/engine/core_client.py
#	vllm/v1/utils.py
…-engines

Signed-off-by: Nick Hill <[email protected]>

# Conflicts:
#	vllm/v1/engine/core.py
#	vllm/v1/engine/core_client.py

github-actions bot commented May 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to be added to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Avoid the exception, but this still needs more work to be functional with multiple API server procs.

Signed-off-by: Nick Hill <[email protected]>
njhill added 2 commits May 5, 2025 14:26
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
@yinghai
Contributor

yinghai commented May 6, 2025

How does run_rpc work if we want to broadcast a call to each engine and have it run exactly once? How can we guarantee that each engine core runs it in lockstep, if we want that?

@yinghai
Contributor

yinghai commented May 6, 2025

There isn't a lot of work in the API server itself that needs multiprocessing, right? It's mostly async_llm, and specifically MM data handling, that needs to scale out?

@njhill
Member Author

njhill commented May 29, 2025

I think all the CI issues are fixed and the remaining failures should be unrelated, but we should let it finish.


mergify bot commented May 29, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @njhill.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 29, 2025
Signed-off-by: Nick Hill <[email protected]>

# Conflicts:
#	tests/v1/engine/test_engine_core.py
#	vllm/v1/engine/core.py
#	vllm/v1/engine/core_client.py
@mergify mergify bot removed the needs-rebase label May 29, 2025
@simon-mo simon-mo enabled auto-merge (squash) May 29, 2025 22:58
@simon-mo simon-mo merged commit 2dbe8c0 into vllm-project:main May 30, 2025
92 of 94 checks passed
@njhill njhill deleted the all-to-all branch May 30, 2025 18:43
@lgeiger
Contributor

lgeiger commented May 31, 2025

I am now seeing the following warning on main when running some tests, e.g.:

pytest tests/v1/engine/test_async_llm.py::test_load -s
ERROR 05-31 01:04:14 [prometheus.py:77] Error during metrics cleanup: expected str, bytes or os.PathLike object, not NoneType

This warning is thrown during the prometheus cleanup, though I'm not sure where exactly it is coming from.

@lgeiger
Contributor

lgeiger commented May 31, 2025

Looks like the error comes from this prometheus function

@njhill
Member Author

njhill commented May 31, 2025

Thanks @lgeiger. I think the message is harmless but I'll fix this.

njhill added a commit to njhill/vllm that referenced this pull request May 31, 2025
Introduced in vllm-project#17546. We should only call mark_process_dead when we're using prometheus multiprocessing mode (i.e. when there is more than one API server).

Signed-off-by: Nick Hill <[email protected]>
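Not the exact patch in #18992, but the guard described above amounts to something like this (the function name is mine; prometheus multiprocess mode is signalled by the PROMETHEUS_MULTIPROC_DIR environment variable, and calling mark_process_dead without it is presumably what produced the NoneType error reported above):

import os
from prometheus_client import multiprocess

def cleanup_prometheus(pid: int) -> None:
    # mark_process_dead is only meaningful when prometheus multiprocess mode is
    # active, i.e. when more than one API server process shares a metrics dir.
    if os.environ.get("PROMETHEUS_MULTIPROC_DIR"):
        multiprocess.mark_process_dead(pid)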
@njhill
Member Author

njhill commented May 31, 2025

@lgeiger fixed in #18992.

@chaunceyjiang
Collaborator

chaunceyjiang commented Jun 6, 2025

Hi @njhill, when I use this PR I encounter an issue where all of my vLLM instances hang.

node1:
 VLLM_ALL2ALL_BACKEND="deepep_high_throughput" VLLM_USE_DEEP_GEMM=1 vllm serve /data/deepseek-ai/DeepSeek-R1  --data-parallel-size 32 --data-parallel-size-local 8 --data-parallel-address 10.254.20.30 --data-parallel-rpc-port 5555 --enable-expert-parallel --tensor-parallel-size 1  --data-parallel-start-rank 0
node2:
 VLLM_ALL2ALL_BACKEND="deepep_high_throughput" VLLM_USE_DEEP_GEMM=1 vllm serve /data/deepseek-ai/DeepSeek-R1  --data-parallel-size 32 --data-parallel-size-local 8 --data-parallel-address 10.254.20.30 --data-parallel-rpc-port 5555 --enable-expert-parallel --tensor-parallel-size 1 --headless --data-parallel-start-rank 8
node3
 VLLM_ALL2ALL_BACKEND="deepep_high_throughput" VLLM_USE_DEEP_GEMM=1 vllm serve /data/deepseek-ai/DeepSeek-R1  --data-parallel-size 32 --data-parallel-size-local 8 --data-parallel-address 10.254.20.30 --data-parallel-rpc-port 5555 --enable-expert-parallel --tensor-parallel-size 1 --headless --data-parallel-start-rank 16
node4:
 VLLM_ALL2ALL_BACKEND="deepep_high_throughput" VLLM_USE_DEEP_GEMM=1 vllm serve /data/deepseek-ai/DeepSeek-R1  --data-parallel-size 32 --data-parallel-size-local 8 --data-parallel-address 10.254.20.30 --data-parallel-rpc-port 5555 --enable-expert-parallel --tensor-parallel-size 1 --headless --data-parallel-start-rank 24

@njhill
Member Author

njhill commented Jun 6, 2025

@chaunceyjiang could you explain what you mean by "when you use this PR"? Do you just mean when you use the latest from main? I don't see that you're setting --api-server-count?

In any case, could you open a new issue if you're still having a problem, with more detail (logs, exact version/commit being used, etc.)?

