Commit 0c38f41

[Feat] Implement primal full graph with limited scenario (vllm-project#1503)
This pull request introduces full-graph capture, replacing the previous piecewise-graph approach. Key improvements include:

* **Reduced dispatch latency:** By capturing the entire model execution graph at once, we minimize overhead compared to multiple smaller captures.
* **Stabilized multi-GPU performance:** Eliminates throughput fluctuations during the `MODEL_EXECUTE` phase across multiple cards.
* **Stream resource savings:** Consolidating graph captures frees up streams, allowing more graphs to be captured concurrently.

**Known issues:**

1. Capturing larger or more numerous graphs increases GPU memory usage, which can lead to OOM errors or inference hangs.
2. The new paged-attention implementation relies on the FIA operator, which in certain workloads is slower than the previous approach, resulting in a regression in end-to-end throughput.

There may be other undiscovered corner cases. This PR is the first in a planned series; we will continue to iterate on and address any remaining issues in subsequent submissions.

```python
compilation_config={
    "full_cuda_graph": True,
},
```

---------

Signed-off-by: Yizhou Liu <[email protected]>
1 parent 9260910 commit 0c38f41
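For reference, enabling the new path from user code only requires passing the compilation config above to `LLM`. A minimal sketch, assuming a small placeholder model and reusing the capture sizes from the e2e test below (both are illustrative choices, not recommendations from this PR):

```python
from vllm import LLM, SamplingParams

# Minimal sketch: opt in to full-graph capture on Ascend NPU.
# "Qwen/Qwen2.5-0.5B-Instruct" is a placeholder model; the capture sizes
# mirror tests/e2e/singlecard/test_aclgraph.py.
llm = LLM(
    "Qwen/Qwen2.5-0.5B-Instruct",
    max_model_len=1024,
    compilation_config={
        "full_cuda_graph": True,
        "cudagraph_capture_sizes": [1, 4, 16, 64, 256],
    },
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32, temperature=0.0))
print(outputs[0].outputs[0].text)
```

Given the known issues above, memory headroom matters more with full-graph capture enabled, so trimming the capture-size list is a reasonable first mitigation if OOM occurs.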

File tree: 8 files changed (+315, -31 lines)


tests/e2e/singlecard/test_aclgraph.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -36,9 +36,12 @@
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("full_graph", [False])
 def test_models_with_aclgraph(
     model: str,
     max_tokens: int,
+    full_graph: bool,
+    monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     prompts = [
         "Hello, my name is", "The president of the United States is",
@@ -48,7 +51,15 @@ def test_models_with_aclgraph(
     sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
     # TODO: change to use vllmrunner when the registry of custom op is solved
     # while running pytest
-    vllm_model = LLM(model, max_model_len=1024)
+    if full_graph:
+        vllm_model = LLM(model,
+                         compilation_config={
+                             "full_cuda_graph": True,
+                             "cudagraph_capture_sizes":
+                             [1, 4, 16, 64, 256]
+                         })
+    else:
+        vllm_model = LLM(model, max_model_len=1024)
     vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
     del vllm_model
     torch.npu.empty_cache()
```

vllm_ascend/attention/attention_v1.py

Lines changed: 137 additions & 23 deletions
```diff
@@ -24,12 +24,15 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
 
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, get_graph_params, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
@@ -132,7 +135,7 @@ class AscendMetadata:
     # tokens + new tokens (is None if it is a decoding).
     # (batch_size,)
     seq_lens: torch.Tensor = None
-
+    seq_lens_list: list
     query_start_loc: torch.Tensor = None
     query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
@@ -167,6 +170,7 @@ def build(self,
               num_reqs,
               num_actual_tokens,
               max_query_len,
+              common_attn_metadata: CommonAttentionMetadata,
               enable_dbo_across_dp: bool = False,
               is_only_prefill: bool = False):
 
@@ -175,15 +179,16 @@ def build(self,
         block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
             block_table[:num_reqs])
 
-        query_lens = self.runner.query_lens
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs]
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True)
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
+        # TODO: Refactor these two param to common metadata in runners,
+        # preparing for the hybrid KV groups feature
+        query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
+        seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
+
+        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
         attn_state = self.runner.attn_state
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
-                                                 non_blocking=True)
 
         if is_310p():
             if attn_state == AscendAttentionState.PrefillNoCache:
@@ -201,6 +206,7 @@ def build(self,
             query_start_loc=query_start_loc,
             query_lens=query_lens,
             seq_lens=seq_lens,
+            seq_lens_list=seq_lens_list,
             max_query_len=max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
@@ -209,6 +215,34 @@ def build(self,
             is_only_prefill=is_only_prefill)
         return attn_metadata
 
+    def build_dummy_metadata(self, num_actual_tokens, num_reqs,
+                             num_scheduled_tokens, attn_state):
+        if attn_state == AscendAttentionState.DecodeOnly:
+            # NOTE: We only need to pay attention to seq_lens_list and block_table here
+            common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
+                                                           num_reqs)
+
+            block_table = self.runner.input_batch.block_table[0].block_table
+            block_table[:num_reqs, 0] = torch.arange(1,
+                                                     num_reqs + 1,
+                                                     device=block_table.device,
+                                                     dtype=block_table.dtype)
+
+            attn_metadata = self.build(
+                num_reqs=num_reqs,
+                num_actual_tokens=num_actual_tokens,
+                max_query_len=num_scheduled_tokens.max(),
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly state"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata
+
 
 class AscendAttentionBackendImpl(AttentionImpl):
 
@@ -245,6 +279,10 @@ def __init__(
         self.key_cache = None
         self.value_cache = None
 
+        vllm_config = get_current_vllm_config()
+        self.full_graph = vllm_config.compilation_config.full_cuda_graph
+        self.block_size = vllm_config.cache_config.block_size
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -369,20 +407,96 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            if is_310p():
-                # # seq_lens_tensor needs to be transferred to the device for 310P
-                attn_metadata.seq_lens = \
-                    attn_metadata.seq_lens.to(device=query.device)
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+            if self.full_graph:
+                graph_params = get_graph_params()
+                q = query.view(num_tokens, -1, self.hidden_size)
+                k = self.key_cache.view(  # type: ignore
+                    -1, self.block_size,
+                    self.num_kv_heads * self.head_size)
+                v = self.value_cache.view(  # type: ignore
+                    -1, self.block_size,
+                    self.num_kv_heads * self.head_size)
+                actual_seq_lens = attn_metadata.seq_lens_list
+                attn_args = {
+                    "query": q,
+                    "key": k,
+                    "value": v,
+                    "actual_seq_lengths_kv": actual_seq_lens,
+                    "block_table": attn_metadata.block_tables,
+                    "num_heads": self.num_heads,
+                    "scale": self.scale,
+                    "input_layout": "BSH",
+                    "num_key_value_heads": self.num_kv_heads,
+                    "block_size": self.block_size,
+                }
+
+                # Prepare tensors for attention output
+                # TODO: Refactor this to step-level instead of layer-level
+                attn_output = torch.empty(num_tokens,
+                                          1,
+                                          self.hidden_size,
+                                          dtype=output.dtype,
+                                          device=output.device)
+                softmax_lse = torch.empty(num_tokens,
+                                          dtype=output.dtype,
+                                          device=output.device)
+
+                # Get workspace from cache or calculate it if not present.
+                workspace = graph_params.workspaces.get(num_tokens)
+                if workspace is None:
+                    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                        **attn_args)
+                    graph_params.workspaces[num_tokens] = workspace
+
+                forward_context = get_forward_context()
+                if not forward_context.capturing:
+                    # Execute attention kernel directly in non-capturing mode
+                    torch.ops.npu.npu_fused_infer_attention_score.out(
+                        workspace=workspace,
+                        out=[attn_output, softmax_lse],
+                        **attn_args)
+                else:
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+
+                    graph_params.attn_params[num_tokens].append(
+                        (q, k, v, actual_seq_lens,
+                         attn_metadata.block_tables, self.num_heads,
+                         self.scale, self.num_kv_heads, attn_output,
+                         softmax_lse))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch.ops.npu.npu_fused_infer_attention_score.out(
+                        workspace=workspace,
+                        out=[attn_output, softmax_lse],
+                        **attn_args)
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
+
+                # Reshape output to match the expected format
+                output.copy_(
+                    attn_output.view(num_tokens, self.num_heads,
+                                     self.head_size))
+            else:
+                if is_310p():
+                    # seq_lens_tensor needs to be transferred to the device for 310P
+                    attn_metadata.seq_lens = \
+                        attn_metadata.seq_lens.to(device=query.device)
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
         # Normal V1 situation.
         else:
             # use chunked prefill for head size 192 scenario, like deepseek
```
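All of the full-graph bookkeeping above is keyed by `num_tokens`: the FIA workspace is cached per shape, and during capture each attention layer appends its task-group handle, an `ExternalEvent`, and the tensors/arguments of its FIA call so they can be patched at replay time. The container returned by `get_graph_params()` comes from `vllm_ascend.utils`, whose diff is not shown in this excerpt; a minimal sketch of how such a container could be organized, assuming plain dictionaries keyed by captured batch size, is:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class GraphParamsSketch:
    """Illustrative stand-in for the container behind get_graph_params().

    One slot per captured batch size (num_tokens). Each attention layer
    appends its handle, event, and FIA arguments in capture order, so a
    single layer index lines up across the three lists when
    update_attn_params() walks them after replay.
    """
    workspaces: Dict[int, Any] = field(default_factory=dict)
    handles: Dict[int, List[Any]] = field(default_factory=dict)
    events: Dict[int, List[Any]] = field(default_factory=dict)
    attn_params: Dict[int, List[Tuple]] = field(default_factory=dict)


def init_graph_params_sketch(capture_sizes: List[int]) -> GraphParamsSketch:
    # Mirrors what set_graph_params(aclgraph_capture_sizes) presumably does:
    # pre-create empty per-shape slots for every size that will be captured.
    params = GraphParamsSketch()
    for size in capture_sizes:
        params.handles[size] = []
        params.events[size] = []
        params.attn_params[size] = []
    return params
```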

vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -17,6 +17,8 @@
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
```

vllm_ascend/attention/utils.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -0,0 +1,23 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+
+@dataclass
+class AscendCommonAttentionMetadata:
+    """
+    Attention metadata attributes that can be shared by layers in different KV
+    cache groups and thus having different block table.
+    """
+
+    query_start_loc: torch.Tensor = None
+    """(batch_size + 1,), the start location of each request in query Tensor"""
+    seq_lens: Optional[torch.Tensor] = None
+    """(batch_size,), the length of each request including both computed tokens
+    and newly scheduled tokens"""
+    query_lens: Optional[torch.Tensor] = None
+    """(batch_size,), the length of each request including only the newly
+    scheduled tokens"""
+    seq_lens_list: Optional[list] = None
+    """(num_input_tokens,), note that this is specifically for FIA kernel"""
```

vllm_ascend/compilation/piecewise_backend.py

Lines changed: 52 additions & 0 deletions
```diff
@@ -28,9 +28,13 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.monitor import end_monitoring_torch_compile
 from vllm.config import VllmConfig
+from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 from vllm.utils import weak_ref_tensors
 
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.utils import get_graph_params, set_graph_params
+
 
 @dataclasses.dataclass
 class ConcreteSizeEntry:
@@ -95,6 +99,10 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
 
         self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
 
+        if self.compilation_config.full_cuda_graph:
+            self.update_stream = torch.npu.Stream()
+            set_graph_params(self.aclgraph_capture_sizes)
+
         # the entries for different shapes that we need to either
         # compile or capture aclgraph
         self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
@@ -116,7 +124,40 @@ def check_for_ending_compilation(self):
             self.vllm_backend.compiler_manager.save_to_file()
         end_monitoring_torch_compile(self.vllm_config)
 
+    def update_attn_params(self, graph_params, forward_context, runtime_shape):
+        for layer_idx in range(len(graph_params.handles[runtime_shape])):
+            query, key, value, actual_seq_lens, block_table, num_heads, scale, num_kv_heads, output, softmax_lse = graph_params.attn_params[
+                runtime_shape][layer_idx]
+            block_table = forward_context.attn_metadata.block_tables
+            actual_seq_lens = forward_context.attn_metadata.seq_lens_list
+
+            with torch.npu.stream(self.update_stream):
+                torch.npu.graph_task_update_begin(
+                    self.update_stream,
+                    graph_params.handles[runtime_shape][layer_idx])
+                torch.ops.npu.npu_fused_infer_attention_score.out(
+                    query,
+                    key,
+                    value,
+                    workspace=graph_params.workspaces[runtime_shape],
+                    actual_seq_lengths_kv=actual_seq_lens,
+                    block_table=block_table,
+                    num_heads=num_heads,
+                    scale=scale,
+                    input_layout="BSH",
+                    num_key_value_heads=num_kv_heads,
+                    block_size=128,
+                    out=[output, softmax_lse],
+                )
+                torch.npu.graph_task_update_end(self.update_stream)
+
+            graph_params.events[runtime_shape][layer_idx].record(
+                self.update_stream)
+
     def __call__(self, *args) -> Any:
+        forward_context = get_forward_context()
+        graph_params = get_graph_params()
+
         if not self.first_run_finished:
             self.first_run_finished = True
             self.check_for_ending_compilation()
@@ -127,6 +168,11 @@ def __call__(self, *args) -> Any:
             # we don't need to do anything for this shape
             return self.compiled_graph_for_general_shape(*args)
 
+        if (getattr(forward_context.attn_metadata, "attn_state",
+                    None) != AscendAttentionState.DecodeOnly
+                and self.compilation_config.full_cuda_graph):
+            return self.compiled_graph_for_general_shape(*args)
+
         entry = self.concrete_size_entries[runtime_shape]
 
         if entry.runnable is None:
@@ -189,6 +235,7 @@ def __call__(self, *args) -> Any:
                 patch("torch.npu.empty_cache", lambda: None))
 
             # mind-exploding: carefully manage the reference and memory.
+            forward_context.capturing = True
            with torch.npu.graph(aclgraph, pool=self.graph_pool):
                # `output` is managed by pytorch's aclgraph pool
                output = entry.runnable(*args)
@@ -222,4 +269,9 @@ def __call__(self, *args) -> Any:
             )
 
         entry.aclgraph.replay()
+
+        if self.compilation_config.full_cuda_graph:
+            self.update_attn_params(graph_params, forward_context,
+                                    runtime_shape)
+
         return entry.output
```
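Taken together, the `__call__` changes mean a captured full graph is only replayed for decode-only batches at a pre-captured size; any other batch falls back to the general compiled graph, and after each replay `update_attn_params` re-issues every recorded FIA call on a side stream with the current step's block tables and sequence lengths, then records the per-layer events the captured graph waits on. A small, framework-free distillation of the dispatch rule (names are illustrative, not part of the PR):

```python
from typing import Optional, Sequence


def use_captured_full_graph(attn_state: Optional[str],
                            full_cuda_graph: bool,
                            runtime_shape: int,
                            capture_sizes: Sequence[int]) -> bool:
    """Illustrative distillation of the dispatch logic in __call__ above.

    A captured full graph is replayed only when full-graph mode is on, the
    batch is a decode-only step (the single supported attention state in
    this PR), and the runtime shape was one of the captured sizes.
    """
    if not full_cuda_graph:
        return False
    if attn_state != "DecodeOnly":
        return False
    return runtime_shape in capture_sizes


# With capture sizes [1, 4, 16, 64, 256], a prefill batch still takes the
# general graph, while a decode batch of 64 requests replays the capture.
assert not use_captured_full_graph("PrefillNoCache", True, 64, [1, 4, 16, 64, 256])
assert use_captured_full_graph("DecodeOnly", True, 64, [1, 4, 16, 64, 256])
```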

vllm_ascend/platform.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -163,8 +163,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
                 "using only ACL Graph mode")
             compilation_config.use_inductor = False
-            compilation_config.splitting_ops.extend(
-                ["vllm.unified_ascend_attention_with_output"])
+            if not compilation_config.full_cuda_graph:
+                compilation_config.splitting_ops.extend(
+                    ["vllm.unified_ascend_attention_with_output"])
             update_aclgraph_sizes(vllm_config)
 
         if parallel_config and parallel_config.worker_cls == "auto":
```
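The net effect is that attention is only registered as a splitting op in piecewise mode; with `full_cuda_graph` enabled, `vllm.unified_ascend_attention_with_output` stays inside the captured graph. A rough sketch of the resulting behavior, using a bare stand-in for vLLM's `CompilationConfig` so the field handling stays simple:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class CompilationConfigSketch:
    """Simplified stand-in for vllm.config.CompilationConfig (assumption)."""
    full_cuda_graph: bool = False
    use_inductor: bool = True
    splitting_ops: List[str] = field(default_factory=list)


def check_and_update_config_sketch(cfg: CompilationConfigSketch) -> None:
    # Mirrors the platform.py hunk above: inductor is always disabled on NPU,
    # and attention is split out of the graph only in piecewise mode.
    cfg.use_inductor = False
    if not cfg.full_cuda_graph:
        cfg.splitting_ops.extend(["vllm.unified_ascend_attention_with_output"])


piecewise = CompilationConfigSketch(full_cuda_graph=False)
full = CompilationConfigSketch(full_cuda_graph=True)
check_and_update_config_sketch(piecewise)
check_and_update_config_sketch(full)
assert "vllm.unified_ascend_attention_with_output" in piecewise.splitting_ops
assert full.splitting_ops == []
```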
