
Conversation

@yiz-liu yiz-liu commented Jun 28, 2025

What this PR does / why we need it?

This pull request introduces full-graph capture, replacing the previous piecewise-graph approach. Key improvements include:

  • Reduced dispatch latency: By capturing the entire model execution graph at once, we minimize overhead compared to multiple smaller captures.
  • Stabilized multi-GPU performance: Eliminates throughput fluctuations during the MODEL_EXECUTE phase across multiple cards.
  • Stream resource savings: Consolidating graph captures frees up streams, allowing more graphs to be captured concurrently.
Known issues:

  1. Capturing graphs increases GPU memory usage, which can lead to OOM errors or inference hangs.
  2. The new paged-attention implementation relies on the FIA operator, which in certain workloads is slower than the previous approach, resulting in a regression in end-to-end throughput.

There may be other undiscovered corner cases. This PR is the first in a planned series; we will continue to iterate on and address any remaining issues in subsequent submissions.

Does this PR introduce any user-facing change?

```python
compilation_config={
    "full_cuda_graph": True,
},
```
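
For reference, here is a minimal usage sketch of passing this option through vLLM's offline `LLM` API. The model name is a placeholder, and the `cudagraph_capture_sizes` entry is an assumption about vLLM's `CompilationConfig`, shown only as one possible way to limit the extra memory pressure noted in known issue 1; neither comes from this PR.

```python
from vllm import LLM

# Hedged sketch: enable full-graph capture via the compilation config.
# "Qwen/Qwen2.5-7B-Instruct" is a placeholder model. cudagraph_capture_sizes
# is assumed here as a way to cap how many batch-size shapes get captured,
# which may help with the OOM risk mentioned in the known issues.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    compilation_config={
        "full_cuda_graph": True,
        "cudagraph_capture_sizes": [1, 2, 4, 8],
    },
)

outputs = llm.generate(["Hello, my name is"])
print(outputs[0].outputs[0].text)
```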

How was this patch tested?

@yiz-liu yiz-liu force-pushed the feat-full-graph branch 3 times, most recently from 34e1ac7 to 45d59fd on July 4, 2025 02:07
@yiz-liu yiz-liu force-pushed the feat-full-graph branch 2 times, most recently from 5987715 to ffdb493 on July 4, 2025 08:32
@yiz-liu yiz-liu changed the title from "[WIP][Enhancement] Implement primal full graph with limited scenario" to "[Feat] Implement primal full graph with limited scenario" on Jul 4, 2025

github-actions bot commented Jul 4, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.


huyz-git commented Jul 7, 2025

Should this part also be changed when full graph is enabled? Does it still need to divide by `num_hidden_layers` in full graph mode?

```python
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
                                 (num_hidden_layers + 1) / parallel_factor)
```
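
For context, a rough sketch of the budget arithmetic behind this question. `MAX_CAPTURE_SIZE`, `num_hidden_layers`, and `parallel_factor` are taken from the snippet above; the concrete values and the full-graph branch are illustrative assumptions, not code from this PR.

```python
import math

# Illustrative values only; not taken from the PR.
MAX_CAPTURE_SIZE = 1024   # assumed total budget for captured graphs/streams
num_hidden_layers = 32
parallel_factor = 2

# Piecewise capture: each layer is captured as its own graph, so the budget
# is shared by roughly (num_hidden_layers + 1) graphs per batch size.
piecewise_batch_sizes = math.floor(
    MAX_CAPTURE_SIZE / (num_hidden_layers + 1) / parallel_factor)

# Full-graph capture: one graph covers the whole model per batch size, so
# dividing by the layer count may no longer apply -- which is the question.
full_graph_batch_sizes = math.floor(MAX_CAPTURE_SIZE / parallel_factor)

print(piecewise_batch_sizes, full_graph_batch_sizes)  # e.g. 15 512
```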

@yiz-liu yiz-liu force-pushed the feat-full-graph branch from 7062bd3 to e4ea639 on July 7, 2025 03:10
@ganyi1996ppo ganyi1996ppo merged commit 04e6169 into vllm-project:v0.9.1-dev Jul 7, 2025
16 checks passed
@yiz-liu yiz-liu changed the title from "[Feat] Implement primal full graph with limited scenario" to "[1/N][Feat] Implement primal full graph with limited scenario" on Jul 7, 2025
@Yikun Yikun added the no-main label Jul 7, 2025
@yiz-liu yiz-liu deleted the feat-full-graph branch July 8, 2025 10:56
yiz-liu added a commit to yiz-liu/vllm-ascend that referenced this pull request Jul 31, 2025
…t#1503)

This pull request introduces full-graph capture, replacing the previous
piecewise-graph approach. Key improvements include:

* **Reduced dispatch latency:** By capturing the entire model execution
graph at once, we minimize overhead compared to multiple smaller
captures.
* **Stabilized multi-GPU performance:** Eliminates throughput
fluctuations during the `MODEL_EXECUTE` phase across multiple cards.
* **Stream resource savings:** Consolidating graph captures frees up
streams, allowing more graphs to be captured concurrently.
**Known issues:**

1. Capturing larger or more numerous graphs increases GPU memory usage,
which can lead to OOM errors or inference hangs.
2. The new paged-attention implementation relies on the FIA operator,
which in certain workloads is slower than the previous
approach—resulting in a regression in end-to-end throughput.
There may be other undiscovered corner cases. This PR is the first in a
planned series; we will continue to iterate on and address any remaining
issues in subsequent submissions.

```python
compilation_config={
    "full_cuda_graph": True,
},
```

---------

Signed-off-by: Yizhou Liu <[email protected]>
yiz-liu added a commit to yiz-liu/vllm-ascend that referenced this pull request Aug 1, 2025
…t#1503)

yiz-liu added a commit to yiz-liu/vllm-ascend that referenced this pull request Aug 11, 2025
…t#1503)

yiz-liu added a commit to yiz-liu/vllm-ascend that referenced this pull request Aug 11, 2025
…t#1503)

yiz-liu added a commit to yiz-liu/vllm-ascend that referenced this pull request Aug 12, 2025
…t#1503)

yiz-liu added a commit to yiz-liu/vllm-ascend that referenced this pull request Aug 12, 2025
…t#1503)

yiz-liu added a commit to yiz-liu/vllm-ascend that referenced this pull request Aug 13, 2025
…t#1503)

yiz-liu added a commit to yiz-liu/vllm-ascend that referenced this pull request Aug 15, 2025
…t#1503)

yiz-liu added a commit to yiz-liu/vllm-ascend that referenced this pull request Sep 17, 2025
…th the latest design

Revert "[Feat] Implement primal full graph with limited scenario (vllm-project#1503)"

This reverts commit 14660be.

Signed-off-by: Yizhou Liu <[email protected]>
wangxiyuan pushed a commit that referenced this pull request Sep 22, 2025
Note: This depends on [vLLM #25161](vllm-project/vllm#25161) and the torch_npu release from September 30.

### What this PR does / why we need it?
This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA
models like DeepSeek V3/R1 are not included). Key improvements include:

* **Reduced dispatch latency:** By replaying the entire model execution
graph at once, we cut overhead compared with multiple smaller replays.
* **Stabilized multi-device performance:** Capturing the whole model as
one static graph also mitigates the dispatch fluctuations across
devices.
* **Stream/resource savings:** Consolidating graph captures frees up
streams, allowing more graphs to be captured.

**Known issues:**

1. `_npu_paged_attention` currently manages its own workspace in
`torch_npu`, which can deadlock when synchronizing during graph replay —
we’re working on a fix.

There may be other corner cases. This PR is the first in a planned
series; we’ll continue to iterate and address remaining issues in
follow-ups.

This is essentially a port of #1503 and #1677, but includes two major
changes:

1. Let `graph_dispatcher` decide the graph mode instead of hard-coding
it in the backend, which decouples Full Graph and Piecewise Graph and
could make it possible to remove dynamo.
2. Adapt to the new `attn_group` logic, but leave a small hack in
`update_graph_params`; multi-attention models may or may not be fully
supported yet.

### Does this PR introduce _any_ user-facing change?
```python
compilation_config={
    "cudagraph_mode": "FULL_DECODE_ONLY",
},
```
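
As a usage note (not part of the quoted commit), a minimal sketch of how this mode might be enabled offline; the model name is a placeholder, and it assumes a vLLM/vllm-ascend build where `compilation_config` accepts `cudagraph_mode` as a string.

```python
from vllm import LLM, SamplingParams

# Hedged sketch: request full-graph capture for decode-only batches.
# "Qwen/Qwen2.5-7B-Instruct" is a placeholder model.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
)

params = SamplingParams(max_tokens=32)
print(llm.generate(["The capital of France is"], params)[0].outputs[0].text)
```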

### How was this patch tested?
Tests included.


- vLLM version: v0.10.2
- vLLM main:
vllm-project/vllm@9607d5e

---------

Signed-off-by: Yizhou Liu <[email protected]>
Mercykid-bash pushed a commit to Mercykid-bash/vllm-ascend that referenced this pull request Sep 22, 2025
…m-project#2128)

Mercykid-bash pushed a commit to Mercykid-bash/vllm-ascend that referenced this pull request Sep 22, 2025
…m-project#2128)
