
Conversation

Abatom
Contributor

@Abatom Abatom commented Mar 31, 2025

An implementation of XpYd with dynamic scaling based on point-to-point communication, partly inspired by Dynamo.

DeepSeek-R1 on H20 with 1P2D and V0

In the DeepSeek-R1 inference scenario with 1k input and 1k output tokens, the three H20 machines (1P2D) achieve a TTFT of around 2 seconds and a TPOT ≤ 100 ms. The throughput improvement over deploying vLLM on three standalone H20 machines is 115% (2396/3/370 − 1).

| Configuration | TTFT avg (ms) | TPOT avg (ms) | Input Throughput (tokens/s) | Output Throughput (tokens/s) |
| --- | --- | --- | --- | --- |
| Single | 2022 | 52 | 370 | 370 |
| 1P2D | 2037 | 90 | 2396 | 2396 |

Architecture diagram

(Image: P/D disaggregation architecture diagram)

Explanations:

  • As long as the address of the counterpart is known, point-to-point KV cache transfer (using NCCL) can be performed, without being constrained by rank and world size.
  • The control flow goes through ZMQ, while the data flow goes through NCCL.
  • Supports dynamic scaling (expansion and contraction) of instances with P/D (Prefill and Decode) separation: adding or removing P/D instances does not require a full system restart.
  • Currently, the proxy uses random selection for load balancing (see the sketch after this list). In the future, it can adopt prefix-cache-aware routing to maximize cache hit rates.
  • For the P instance, --max-num-seqs should be set to a small value. In my scenario, I set it to 5 to avoid filling up the buffer of the D instance, which would otherwise cause the D instance to recompute the prefill. To relax this constraint on --max-num-seqs, I will implement a local memory pool that absorbs sudden bursts of KV cache and avoids prefill recomputation.
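
As a rough illustration of the current load-balancing behavior (a minimal sketch with hypothetical names, not the code of disagg_prefill_proxy_xpyd.py):

```python
import random

# Hypothetical endpoint lists; the real proxy learns about P/D instances
# through their ZMQ registration messages and can grow or shrink these at runtime.
prefill_instances = ["1.1.1.1:20001"]                   # P (kv_producer) HTTP endpoints
decode_instances = ["2.2.2.2:20002", "3.3.3.3:20003"]   # D (kv_consumer) HTTP endpoints

def pick_pd_pair():
    """Randomly pick one prefill and one decode instance for an incoming request.
    A prefix-cache-aware router could replace this choice to raise cache hit rates."""
    return random.choice(prefill_instances), random.choice(decode_instances)
```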

TODO

  • Tensor Memory Pool, to absorb sudden increases in KV cache so that D does not have to recompute the prefill (a rough sketch of the idea follows this list).
  • Better support for V1; we will provide detailed support for V1 once the P/D API for V1 becomes stable.
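
The memory pool is not part of this PR yet; the sketch below only illustrates the intent (all names are hypothetical): a fixed-budget staging area that temporarily holds incoming KV tensors when the D instance cannot place them immediately, instead of dropping them and recomputing the prefill.

```python
from typing import Dict, Optional

import torch


class TensorMemoryPool:
    """Illustrative only: stage KV tensors that do not fit into the decode
    instance's receive buffer, up to a fixed byte budget."""

    def __init__(self, capacity_bytes: int, device: str = "cpu"):
        self.capacity = capacity_bytes
        self.used = 0
        self.device = device
        self.staged: Dict[str, torch.Tensor] = {}

    def put(self, tensor_id: str, tensor: torch.Tensor) -> bool:
        size = tensor.element_size() * tensor.numel()
        if self.used + size > self.capacity:
            return False  # pool full; the caller falls back to recomputing the prefill
        self.staged[tensor_id] = tensor.to(self.device, non_blocking=True)
        self.used += size
        return True

    def pop(self, tensor_id: str) -> Optional[torch.Tensor]:
        tensor = self.staged.pop(tensor_id, None)
        if tensor is not None:
            self.used -= tensor.element_size() * tensor.numel()
        return tensor
```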

Install vLLM

cd /home
wget https://vllm-wheels.s3.us-west-2.amazonaws.com/d43f914d42dc00a59ca8b6d26363cf02b3b898b2/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl
export VLLM_PRECOMPILED_WHEEL_LOCATION=/home/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl
cd vllm  # assumes the vLLM source tree containing this PR has already been cloned to /home/vllm
pip install -e . -v

Environment configuration

Delete this line in vllm/env_override.py

os.environ['NCCL_CUMEM_ENABLE'] = '0'

It is highly recommended to set NCCL_CUMEM_ENABLE=1, which lets NCCL allocate its communication buffers through the CUDA cuMem APIs. This can reduce the overhead of GPU memory copies and improve the efficiency of multi-GPU and cross-node communication.

export NCCL_CUMEM_ENABLE=1

In --kv-transfer-config, the send_type field has three mutually exclusive options: PUT, GET, and PUT_ASYNC.

How to run 1P2D with V0? (Stable)

export VLLM_USE_V1=0

Node 1 (IP:1.1.1.1)

python3 disagg_prefill_proxy_xpyd.py

Node 1 (IP:1.1.1.1)

vllm serve Meta-Llama-3.1-8B-Instruct \
    --host 0.0.0.0 \
    --port 20001 \
    --tensor-parallel-size 8 \
    --seed 1024 \
    --served-model-name Llama \
    --max-model-len 32768 \
    --max-num-batched-tokens 32768 \
    --max-num-seqs 5 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --kv-transfer-config \
    '{"kv_connector":"P2pConnector","kv_role":"kv_producer","kv_buffer_size":"1e9","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"1.1.1.1","proxy_port":"30001","http_port":"20001","send_type":"PUT_ASYNC"}}'

Node 2 (IP:2.2.2.2)

vllm serve Meta-Llama-3.1-8B-Instruct \
    --host 0.0.0.0 \
    --port 20002 \
    --tensor-parallel-size 8 \
    --seed 1024 \
    --served-model-name Llama \
    --max-model-len 32768 \
    --max-num-batched-tokens 32768 \
    --max-num-seqs 256 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --kv-transfer-config \
    '{"kv_connector":"P2pConnector","kv_role":"kv_consumer","kv_buffer_size":"1e10","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"1.1.1.1","proxy_port":"30001","http_port":"20002","send_type":"PUT_ASYNC"}}'

Node 3 (IP:3.3.3.3)

vllm serve Meta-Llama-3.1-8B-Instruct \
    --host 0.0.0.0 \
    --port 20003 \
    --tensor-parallel-size 8 \
    --seed 1024 \
    --served-model-name Llama \
    --max-model-len 32768 \
    --max-num-batched-tokens 32768 \
    --max-num-seqs 256 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --kv-transfer-config \
    '{"kv_connector":"P2pConnector","kv_role":"kv_consumer","kv_buffer_size":"1e10","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"1.1.1.1","proxy_port":"30001","http_port":"20003","send_type":"PUT_ASYNC"}}'

How to run 1P2D with V1? (Unstable)

export VLLM_USE_V1=1

Node 1 (IP:1.1.1.1)

python3 disagg_prefill_proxy_xpyd.py

Node 1 (IP:1.1.1.1)

vllm serve Meta-Llama-3.1-8B-Instruct \
    --host 0.0.0.0 \
    --port 20001 \
    --tensor-parallel-size 8 \
    --seed 1024 \
    --served-model-name Llama \
    --max-model-len 32768 \
    --max-num-batched-tokens 32768 \
    --max-num-seqs 5 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e9","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"1.1.1.1","proxy_port":"30001","http_port":"20001","send_type":"PUT_ASYNC"}}'

Node 2 (IP:2.2.2.2)

vllm serve Meta-Llama-3.1-8B-Instruct \
    --host 0.0.0.0 \
    --port 20002 \
    --tensor-parallel-size 8 \
    --seed 1024 \
    --served-model-name Llama \
    --max-model-len 32768 \
    --max-num-batched-tokens 32768 \
    --max-num-seqs 256 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"1e10","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"1.1.1.1","proxy_port":"30001","http_port":"20002","send_type":"PUT_ASYNC"}}'

Node 3 (IP:3.3.3.3)

vllm serve Meta-Llama-3.1-8B-Instruct \
    --host 0.0.0.0 \
    --port 20003 \
    --tensor-parallel-size 8 \
    --seed 1024 \
    --served-model-name Llama \
    --max-model-len 32768 \
    --max-num-batched-tokens 32768 \
    --max-num-seqs 256 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"1e10","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"1.1.1.1","proxy_port":"30001","http_port":"20003","send_type":"PUT_ASYNC"}}'

Request

curl -X POST -s http://localhost:10001/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "base_model",
"prompt": "San Francisco is a",
"max_tokens": 100,
"temperature": 0
}'
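
The same request can also be sent from Python, for example with the requests library (port and model name taken from the curl example above):

```python
import requests

# Mirrors the curl example above; assumes the proxy is listening on localhost:10001.
resp = requests.post(
    "http://localhost:10001/v1/completions",
    json={
        "model": "base_model",
        "prompt": "San Francisco is a",
        "max_tokens": 100,
        "temperature": 0,
    },
)
print(resp.json())
```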

Signed-off-by: Abatom <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the documentation (Improvements or additions to documentation) label Mar 31, 2025
Abatom added 7 commits March 31, 2025 21:33
Signed-off-by: Abatom <[email protected]>
Signed-off-by: Abatom <[email protected]>
Signed-off-by: Abatom <[email protected]>
Signed-off-by: Abatom <[email protected]>
Signed-off-by: Abatom <[email protected]>
Signed-off-by: Abatom <[email protected]>
Signed-off-by: Abatom <[email protected]>
@Abatom Abatom requested a review from maobaolong April 1, 2025 12:42
@robertgshaw2-redhat robertgshaw2-redhat requested review from robertgshaw2-redhat and removed request for maobaolong April 2, 2025 18:20
@Abatom
Contributor Author

Abatom commented Apr 4, 2025

@robertgshaw2-redhat ping!

@robertgshaw2-redhat
Collaborator

robertgshaw2-redhat commented Apr 4, 2025

@robertgshaw2-redhat ping!

My afternoon is blocked for me to focus on this PR. Thanks!

@robertgshaw2-redhat
Collaborator

Are you okay with me deleting the proxy in a follow up?

logger = logging.getLogger(__name__)


class P2pNcclPipe:
Collaborator

This does not inherit from KVPipeBase - I think we should leverage the base class to make sure the implementations remain consistent.

Contributor Author

I need to add two parameters, tensor_id: str = "" and remote_address: Optional[str] = None, to both the send_tensor and recv_tensor functions in KVPipeBase to make sure the implementations remain consistent.
Is this approach acceptable?
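
For reference, the proposed change would look roughly like the following (a sketch only; the actual KVPipeBase in vllm/distributed/kv_transfer may differ in detail):

```python
from abc import ABC, abstractmethod
from typing import Optional

import torch


class KVPipeBase(ABC):
    """Sketch of the base class with the two proposed optional parameters,
    so that point-to-point pipes can address a specific remote instance."""

    @abstractmethod
    def send_tensor(self,
                    tensor: Optional[torch.Tensor],
                    tensor_id: str = "",
                    remote_address: Optional[str] = None) -> None:
        """Send a tensor, optionally tagged with an id and a target address."""
        raise NotImplementedError

    @abstractmethod
    def recv_tensor(self,
                    tensor_id: str = "",
                    remote_address: Optional[str] = None) -> Optional[torch.Tensor]:
        """Receive a tensor, optionally selected by id and source address."""
        raise NotImplementedError
```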

if remote_address not in self.socks:
    sock = self.context.socket(zmq.DEALER)
    sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    sock.connect(f"tcp://{remote_address}")
Collaborator

@robertgshaw2-redhat robertgshaw2-redhat Apr 5, 2025

For this PR, can we leverage ipc?

In general, using zmq sockets with pickle is insecure (see below)

With TCP, any arbitrary user with access to the route can cause remote code execution.

For this prototyping stage, we should only have ipc. When we move to production deployments, we can turn on tcp but with security caveats.
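
To illustrate the suggestion (a minimal sketch, not code from this PR): ZMQ's ipc:// transport uses a local Unix domain socket, so the pickle-based control channel is only reachable from the same host, unlike tcp://, which exposes it to anyone who can reach the route.

```python
import zmq

context = zmq.Context()
sock = context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "decode-instance-1")  # hypothetical identity

# tcp transport: reachable over the network; combined with pickle, anyone who can
# reach this endpoint could trigger unpickling of attacker-controlled data.
# sock.connect("tcp://1.1.1.1:21001")

# ipc transport: local-only Unix domain socket (path is hypothetical),
# the safer default while the connector is still a prototype.
sock.connect("ipc:///tmp/vllm_kv_pipe.sock")
```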

@robertgshaw2-redhat
Collaborator

Thanks @Abatom - generally looks good. Key concerns:

  • Security related to zmq + pickle + tcp
  • The new class does not conform to the Base class hierarchy

@Abatom
Contributor Author

Abatom commented Apr 5, 2025

Are you okay with me deleting the proxy in a follow up?

No problem.

@Abatom
Contributor Author

Abatom commented Apr 5, 2025

Thanks @Abatom - generally looks good. Key concerns:

  • Security related to zmq + pickle + tcp
  • The new class does not conform to the Base class hierarchy

Thank you very much for your code review. I will complete the revisions as soon as possible.

@WXFMAV

WXFMAV commented Apr 28, 2025

To address the issue of setting --max-num-seqs to a small value, I will implement a local memory pool to handle the sudden increase in kvcache to avoid recompute prefill.

What I have observed after trying this over the past few days is that prefill throughput does increase a lot, but decode throughput is basically unchanged compared with V1. Does the decoder need some special configuration? In my scenario, each prompt is about 1k tokens and the response is about 2-3k tokens, so it is a fairly heavy token-generation workload.

@Abatom
Contributor Author

Abatom commented Apr 28, 2025

@WXFMAV, What model, what GPU? It's likely that the temporary GPU memory allocation is too small. You can try reducing the kv_buffer_size or decreasing the --gpu-memory-utilization. Additionally, increasing the number of D instances would help. The fact that D is showing Pending clearly indicates that it is overloaded.

@WXFMAV

WXFMAV commented Apr 28, 2025

@WXFMAV, What model, what GPU? It's likely that the temporary GPU memory allocation is too small. You can try reducing the kv_buffer_size or decreasing the --gpu-memory-utilization. Additionally, increasing the number of D instances would help. The fact that D is showing Pending clearly indicates that it is overloaded.

It is NVIDIA A100 80GB GPUs with a Qwen 7B model. I compared this method against the original vLLM V1 baseline; both experiments use four GPUs at the same request rate of 10 QPS. The baseline works well, but with the xPyD method the pending requests soon pile up and the KV cache is consumed rapidly, indicating that the decoders are overloaded. So the prefill/decode-disaggregated configuration does not improve decode throughput in my scenario, which puzzles me. I also tested GPU memory utilization of 70% and 60%, but neither led to improvements.

@Abatom
Contributor Author

Abatom commented Apr 28, 2025

@WXFMAV OK, I think there are still some parameters that haven't been properly configured. The adjustable parameters include the ratio of P instances to D instances (try 1:7), reducing kv_buffer_size to 1e10 or less, setting export NCCL_CUMEM_ENABLE=1, and avoiding reducing --gpu-memory-utilization as much as possible, since that would affect the batch size.

@WXFMAV

WXFMAV commented Apr 28, 2025

@WXFMAV OK, I think there are still some parameters that haven't been properly configured. The adjustable parameters include the ratio of P instances to D instances (try 1:7), reducing kv_buffer_size to 1e10 or less, setting export NCCL_CUMEM_ENABLE=1, and avoiding reducing --gpu-memory-utilization as much as possible, since that would affect the batch size.

Fine, THANKS, let me try it!

@cyber-pioneer

When is this PR expected to be merged?

@Abatom
Contributor Author

Abatom commented Apr 29, 2025

When is this PR expected to be merged?

After the support for V1 is perfected, it should be possible to merge.

@WXFMAV

WXFMAV commented Apr 29, 2025

@WXFMAV OK, I think there are still some parameters that haven't been properly configured. The adjustable parameters include the ratio of P instances to D instances (try 1:7), reducing kv_buffer_size to 1e10 or less, setting export NCCL_CUMEM_ENABLE=1, and avoiding reducing --gpu-memory-utilization as much as possible, since that would affect the batch size.

Fine, THANKS, let me try it!

I tried this configuration, but the results still don't look good. The setup is now 1P7D with 10 QPS of requests; after roughly 16k cumulative requests there are a large number of "Failed to receive all KVs" errors.
vLLM V1 can sustain 10 QPS with 4 GPUs, while 1P7D uses 8 GPUs yet cannot handle 10 QPS.

The model is Qwen2-7B. The dataset averages 1000 input tokens and 2000-3000 output tokens per request. The first ~1000 prompts run without errors, but around 16k prompts the errors start to appear. The GPU configuration is as follows:

GPU 0: NVIDIA A100-SXM4-80GB (UUID: )
GPU 1: NVIDIA A100-SXM4-80GB (UUID: )
GPU 2: NVIDIA A100-SXM4-80GB (UUID: )
GPU 3: NVIDIA A100-SXM4-80GB (UUID: )
GPU 4: NVIDIA A100-SXM4-80GB (UUID: )
GPU 5: NVIDIA A100-SXM4-80GB (UUID: )
GPU 6: NVIDIA A100-SXM4-80GB (UUID: )
GPU 7: NVIDIA A100-SXM4-80GB (UUID: )

The specific prefill and decode launch commands are as follows:

detaillog=$result_folder"/detail_log"
device_p1=0
kvconfig_p1='{"kv_buffer_size":2e10,"kv_connector":"P2pConnector","kv_role":"kv_producer","kv_port":"'${port_kv_train[$device_p1]}'","kv_connector_extra_config":{"proxy_ip":"0.0.0.0","proxy_port":"'${port_proxy_inner[$device_p1]}'","http_port":"'${port_server[$device_p1]}'"}}'
  
VLLM_RPC_TIMEOUT=1800000 \
VLLM_TORCH_PROFILER_DIR=$result_folder \
VLLM_USE_V1=0 \
NCCL_CUMEM_ENABLE=1 \
CUDA_VISIBLE_DEVICES=$device_p1 \
vllm serve $model \
  --chat-template template.jinja \
  --port ${port_server[$device_p1]} \
  --dtype float16 \
  --trust-remote-code \
  --max-model-len 8192 \
  --max-seq-len-to-capture 8192 \
  --max-num-seqs 2 \
  --kv-transfer-config $kvconfig_p1 \
  --gpu-memory-utilization 0.7 \
  --tensor-parallel-size 1 >$detaillog".device."$device_p1".txt" 2>&1 &  
pid_p1=$!

device_dd_list="1 2 3 4 5 6 7"
pid_d_all=""

for device_dd in $device_dd_list; do
  kvconfig_dd='{"kv_buffer_size":1e10,"kv_connector":"P2pConnector","kv_role":"kv_consumer","kv_port":"'${port_kv_train[$device_dd]}'","kv_connector_extra_config":{"proxy_ip":"0.0.0.0","proxy_port":"'${port_proxy_inner[$device_dd]}'","http_port":"'${port_server[$device_dd]}'"}}'        
  VLLM_RPC_TIMEOUT=1800000 \
  VLLM_TORCH_PROFILER_DIR=$result_folder \
  VLLM_USE_V1=0 \
  NCCL_CUMEM_ENABLE=1 \
  CUDA_VISIBLE_DEVICES=$device_dd \
  vllm serve $model \
    --chat-template template.jinja \
    --port ${port_server[$device_dd]} \
    --dtype float16 \
    --trust-remote-code \
    --max-model-len 8192 \
    --max-seq-len-to-capture 8192 \
    --kv-transfer-config $kvconfig_dd \
    --gpu-memory-utilization 0.7 \
    --tensor-parallel-size 1 >$detaillog".device."$device_dd".txt" 2>&1 &
  pid_dd=$!
  pid_d_all="$pid_d_all $pid_dd"
done

The prefill metrics are as follows:

INFO 04-28 10:18:07 [metrics.py:489] Avg prompt throughput: 10376.3 tokens/s, Avg generation throughput: 10.3 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 1 reqs, GPU KV cache usage: 0.1%, CPU KV cache usage: 0.0%.
INFO 04-28 10:18:13 [metrics.py:489] Avg prompt throughput: 8967.2 tokens/s, Avg generation throughput: 8.4 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.1%, CPU KV cache usage: 0.0%.

The decode metrics mostly show a large number of pending sequences:

INFO 04-28 10:18:22 [metrics.py:489] Avg prompt throughput: 447.3 tokens/s, Avg generation throughput: 4764.9 tokens/s, Running: 256 reqs, Swapped: 0 reqs, Pending: 153 reqs, GPU KV cache usage: 52.4%, CPU KV cache usage: 0.0%.
INFO 04-28 10:19:39 [metrics.py:489] Avg prompt throughput: 3675.0 tokens/s, Avg generation throughput: 3252.0 tokens/s, Running: 255 reqs, Swapped: 0 reqs, Pending: 143 reqs, GPU KV cache usage: 78.6%, CPU KV cache usage: 0.0%.

INFO 04-28 10:18:58 [metrics.py:489] Avg prompt throughput: 653.8 tokens/s, Avg generation throughput: 3582.7 tokens/s, Running: 256 reqs, Swapped: 0 reqs, Pending: 247 reqs, GPU KV cache usage: 99.3%, CPU KV cache usage: 0.0%.
INFO 04-28 10:19:18 [metrics.py:489] Avg prompt throughput: 1645.8 tokens/s, Avg generation throughput: 3356.9 tokens/s, Running: 255 reqs, Swapped: 0 reqs, Pending: 257 reqs, GPU KV cache usage: 94.1%, CPU KV cache usage: 0.0%.

The receive-overflow errors from one of the decoders are as follows:

WARNING 04-28 10:15:05 [p2p_nccl_pipe.py:204] 🔴[PUT]Recv From 100.64.160.226:21001, tensor_id:cmpl-___prefill_addr_100.64.160.226:21001___decode_addr_100.64.160.226:25001_1025d17ada3146a4aaa0e80fb3820e5e-0kv, duration:0.010ms, rank:0
INFO 04-28 10:15:05 [p2p_nccl_pipe.py:197] 🔵[PUT]Recv From 100.64.160.226:21001, tensor_id:cmpl-___prefill_addr_100.64.160.226:21001___decode_addr_100.64.160.226:25001_1025d17ada3146a4aaa0e80fb3820e5e-0hidden, shape:torch.Size([2253, 3584]), duration:0.003ms, size:0.015GB, rank:0
WARNING 04-28 10:15:05 [p2p_connector.py:164] [rank0]: Failed to receive all KVs and hidden states, redo model forwarding.
WARNING 04-28 10:15:06 [p2p_nccl_pipe.py:204] 🔴[PUT]Recv From 100.64.160.226:21001, tensor_id:cmpl-___prefill_addr_100.64.160.226:21001___decode_addr_100.64.160.226:25001_9e1b74521bf8450292b84bfa437ce78a-0kv, duration:0.009ms, rank:0
INFO 04-28 10:15:06 [p2p_nccl_pipe.py:197] 🔵[PUT]Recv From 100.64.160.226:21001, tensor_id:cmpl-___prefill_addr_100.64.160.226:21001___decode_addr_100.64.160.226:25001_9e1b74521bf8450292b84bfa437ce78a-0hidden, shape:torch.Size([1447, 3584]), duration:0.003ms, size:0.010GB, rank:0
WARNING 04-28 10:15:06 [p2p_connector.py:164] [rank0]: Failed to receive all KVs and hidden states, redo model forwarding.
WARNING 04-28 10:15:06 [p2p_nccl_pipe.py:204] 🔴[PUT]Recv From 100.64.160.226:21001, tensor_id:cmpl-___prefill_addr_100.64.160.226:21001___decode_addr_100.64.160.226:25001_b099c336345e4f2e87e417e04c3525e0-0kv, duration:0.021ms, rank:0
INFO 04-28 10:15:06 [p2p_nccl_pipe.py:197] 🔵[PUT]Recv From 100.64.160.226:21001, tensor_id:cmpl-___prefill_addr_100.64.160.226:21001___decode_addr_100.64.160.226:25001_b099c336345e4f2e87e417e04c3525e0-0hidden, shape:torch.Size([746, 3584]), duration:0.004ms, size:0.005GB, rank:0
WARNING 04-28 10:15:06 [p2p_connector.py:164] [rank0]: Failed to receive all KVs and hidden states, redo model forwarding.
WARNING 04-28 10:15:06 [p2p_nccl_pipe.py:204] 🔴[PUT]Recv From 100.64.160.226:21001, tensor_id:cmpl-___prefill_addr_100.64.160.226:21001___decode_addr_100.64.160.226:25001_5219d7b1e23847ad9049b9851a914d89-0kv, duration:0.013ms, rank:0
INFO 04-28 10:15:06 [p2p_nccl_pipe.py:197] 🔵[PUT]Recv From 100.64.160.226:21001, tensor_id:cmpl-___prefill_addr_100.64.160.226:21001___decode_addr_100.64.160.226:25001_5219d7b1e23847ad9049b9851a914d89-0hidden, shape:torch.Size([2046, 3584]), duration:0.009ms, size:0.014GB, rank:0
WARNING 04-28 10:15:06 [p2p_nccl_pipe.py:204] 🔴[PUT]Recv From 100.64.160.226:21001, tensor_id:cmpl-___prefill_addr_100.64.160.226:21001___decode_addr_100.64.160.226:25001_50506d595aef469399fba0b7578c89ab-0kv, duration:0.006ms, rank:0
INFO 04-28 10:15:06 [p2p_nccl_pipe.py:197] 🔵[PUT]Recv From 100.64.160.226:21001, tensor_id:cmpl-___prefill_addr_100.64.160.2

With 4 GPUs in the baseline configuration, 10 QPS can be handled. The baseline KV cache usage is as follows:

INFO 04-24 07:46:41 [loggers.py:87] Engine 000: Avg prompt throughput: 3718.2 tokens/s, Avg generation throughput: 2926.4 tokens/s, Running: 158 reqs, Waiting: 0 reqs, GPU KV cache usage: 26.5%, Prefix cache hit rate: 23.8%
INFO 04-24 07:46:51 [loggers.py:87] Engine 000: Avg prompt throughput: 2669.6 tokens/s, Avg generation throughput: 3549.1 tokens/s, Running: 160 reqs, Waiting: 0 reqs, GPU KV cache usage: 27.3%, Prefix cache hit rate: 23.8%

The baseline launch commands are as follows:


device_i0=4
    port_i0=8100
    CUDA_VISIBLE_DEVICES=$device_i0 \
    python3 -m vllm.entrypoints.openai.api_server \
      --model $model \
      --chat-template template.jinja \
      --port $port_i0 \
      --dtype float16 \
      --enforce-eager \
      --trust-remote-code \
      --gpu-memory-utilization 0.9 \
      --tensor_parallel_size 1 >$detaillog".device."$device_i0".txt" 2>&1 &  
    pid_i0=$!

    device_i1=5
    port_i1=8200
    CUDA_VISIBLE_DEVICES=$device_i1 \
    python3 -m vllm.entrypoints.openai.api_server \
      --model $model \
      --chat-template template.jinja \
      --port $port_i1 \
      --dtype float16 \
      --enforce-eager \
      --trust-remote-code \
      --gpu-memory-utilization 0.9 \
      --tensor_parallel_size 1 >$detaillog".device."$device_i1".txt" 2>&1 &  
    pid_i1=$!

    device_i2=6
    port_i2=8300
    CUDA_VISIBLE_DEVICES=$device_i2 \
    python3 -m vllm.entrypoints.openai.api_server \
      --model $model \
      --chat-template template.jinja \
      --port $port_i2 \
      --dtype float16 \
      --enforce-eager \
      --trust-remote-code \
      --gpu-memory-utilization 0.9 \
      --tensor_parallel_size 1 >$detaillog".device."$device_i2".txt" 2>&1 &  
    pid_i2=$!

    device_i3=7
    port_i3=8400
    CUDA_VISIBLE_DEVICES=$device_i3 \
    python3 -m vllm.entrypoints.openai.api_server \
      --model $model \
      --chat-template template.jinja \
      --port $port_i3 \
      --dtype float16 \
      --enforce-eager \
      --trust-remote-code \
      --gpu-memory-utilization 0.9 \
      --tensor_parallel_size 1 >$detaillog".device."$device_i3".txt" 2>&1 &  
    pid_i3=$!

@Abatom
Contributor Author

Abatom commented Apr 29, 2025

The model is Qwen2-7B. The dataset averages 1000 input tokens and 2000-3000 output tokens per request. The first ~1000 prompts run without errors, but around 16k prompts the errors start to appear.

@WXFMAV This is because the current proxy is still quite simple: it just picks a D instance at random. The load is fairly even at first, but it drifts and becomes uneven over time.
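
One simple way to avoid that drift (illustrative only, not part of this PR; the FlagScale integration referenced below offers a "robin" strategy) is round-robin selection over the decode instances instead of a random choice:

```python
import itertools

# Hypothetical endpoint list; the real proxy discovers D instances dynamically.
decode_instances = ["0.0.0.0:25001", "0.0.0.0:25002", "0.0.0.0:25003"]
_round_robin = itertools.cycle(decode_instances)

def pick_decode_instance() -> str:
    """Round-robin keeps the load even over time, whereas random selection can
    leave some D instances overloaded while others sit idle."""
    return next(_round_robin)
```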

aoyulong pushed a commit to FlagOpen/FlagScale that referenced this pull request Apr 29, 2025
### Description: Multi-node Prefill/Decode Disaggregated Deployment with
FlagCX

This PR implements support for multi-node disaggregated deployment of
**prefill** and **decode** stages using `xPyD` Disaggregation:
- Schedule strategies of PD instances currently support: `robin`,
`random`. default is `robin`.
- It introduces a new communication backend based on
[FlagCX](https://github.com/FlagOpen/FlagCX). Merge [FlagCX
Adapter](#461).
- KV cache transfer is enabled via
[p2pConnector](vllm-project/vllm#15806) in
`vLLM`.


---

### How to Use

**Step 1**: Install
[FlagCX](https://github.com/FlagOpen/FlagCX?tab=readme-ov-file#quick-start)

**Step 2**: Install the `vLLM` version from
[FlagScale](https://github.com/FlagOpen/FlagScale?tab=readme-ov-file#setup)

**Step 3**: Define your config files under `./examples/qwen/conf`

**Step 4**: Launch the distributed deployment  
```bash
python run.py --config-path ./examples/qwen/conf --config-name config_qwen2.5_7b_disagg_xpyd action=run
```

**Step 5**: Send requests to the deployed service  
```bash
curl -X POST -s http://localhost:10001/v1/completions \
-H "Content-Type: application/json" \
-d '{
  "model": "/models/Qwen2.5-7B-Instruct",
  "prompt": "Introduce Bruce Lee in details",
  "max_tokens": 100,
  "temperature": 0,
  "stream": true
}'
```
@Abatom Abatom closed this May 16, 2025
@Abatom Abatom changed the title from "[WIP][V1/0][P/D] XpYd based on p2p communication without cache store" to "[V1/0][P/D] XpYd based on p2p communication without cache store" May 16, 2025
@Abatom Abatom deleted the xpyd branch June 25, 2025 02:24
@Abatom Abatom restored the xpyd branch June 25, 2025 02:25
@Abatom Abatom deleted the xpyd branch July 2, 2025 03:37