Merged

Commits (33)
4f6e1b4  init (BoyuanFeng, Sep 4, 2025)
1c1b600  nit (BoyuanFeng, Sep 4, 2025)
50d1dda  nit (BoyuanFeng, Sep 4, 2025)
7218e2b  nit (BoyuanFeng, Sep 5, 2025)
71209e2  cleanup (BoyuanFeng, Sep 5, 2025)
202b6f3  add doc (BoyuanFeng, Sep 5, 2025)
0b1e18a  improve warn/error msg (BoyuanFeng, Sep 5, 2025)
b66568b  match new torch api (BoyuanFeng, Sep 5, 2025)
87c74dd  skip cudagraph for get_input_embedding (BoyuanFeng, Sep 9, 2025)
c0bd3fb  Update vllm/compilation/backends.py (BoyuanFeng, Sep 11, 2025)
e16e23a  Apply suggestions from code review (BoyuanFeng, Sep 11, 2025)
892ab46  more docs (BoyuanFeng, Sep 11, 2025)
eabb1b6  nit (BoyuanFeng, Sep 11, 2025)
04e9801  Update vllm/v1/cudagraph_dispatcher.py (BoyuanFeng, Sep 12, 2025)
6cf5bd5  add piecewise test (BoyuanFeng, Sep 14, 2025)
70f45da  lint (BoyuanFeng, Sep 14, 2025)
7eb5d57  Merge branch 'main' into bf/cg-partition (BoyuanFeng, Sep 15, 2025)
3a6abd8  Merge branch 'main' into bf/cg-partition (BoyuanFeng, Sep 15, 2025)
4cce30c  add custom compile config test (BoyuanFeng, Sep 15, 2025)
d3809fb  more tests for splitting_ops (BoyuanFeng, Sep 16, 2025)
d7a73db  add tests for attention_quant_pattern (BoyuanFeng, Sep 17, 2025)
289a60e  rearch is_attention_compiled_piecewise (BoyuanFeng, Sep 17, 2025)
29ae5f0  Merge branch 'main' into bf/cg-partition (BoyuanFeng, Sep 18, 2025)
b5972fa  move set/unset wrapper to support_torch_compile for frame-specific (BoyuanFeng, Sep 18, 2025)
7570f4b  update test_attention_quant_pattern (BoyuanFeng, Sep 18, 2025)
c7ff7c4  Update vllm/config/compilation.py (BoyuanFeng, Sep 18, 2025)
4a38b36  more tests (BoyuanFeng, Sep 18, 2025)
d4269d9  move wrapper set/unset to context manager (BoyuanFeng, Sep 18, 2025)
20b9ef1  nit (BoyuanFeng, Sep 18, 2025)
e055458  update test (BoyuanFeng, Sep 19, 2025)
45b7588  Merge branch 'main' into bf/cg-partition (BoyuanFeng, Sep 19, 2025)
91c03a4  move maybe_use_cudagraph_partition_wrapper to decorators.py (BoyuanFeng, Sep 19, 2025)
19787d3  test inductor graph partition only when >= torch2.9 (BoyuanFeng, Sep 19, 2025)
57 changes: 48 additions & 9 deletions tests/compile/piecewise/test_simple.py
@@ -50,16 +50,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


@pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode()
def test_simple_piecewise_compile(use_inductor):
assert VLLM_USE_V1

def _run_simple_model(
splitting_ops,
use_inductor_graph_partition,
use_inductor,
expected_num_piecewise_graphs_seen,
expected_num_piecewise_capturable_graphs_seen,
expected_num_backend_compilations,
):
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
use_inductor=use_inductor,
splitting_ops=["silly.attention"],
splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_copy_inputs=True,
cudagraph_capture_sizes=[1, 2],
))
@@ -70,9 +74,10 @@ def test_simple_piecewise_compile(use_inductor):

with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=
expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
), set_forward_context(None,
@@ -104,3 +109,37 @@ def test_simple_piecewise_compile(use_inductor):
output = model(input)
assert get_global_counter() == 2
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))


@pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode()
def test_simple_piecewise_compile(use_inductor):
assert VLLM_USE_V1
_run_simple_model(
splitting_ops=["silly.attention"],
use_inductor_graph_partition=False,
use_inductor=use_inductor,
expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1
expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
expected_num_backend_compilations=
3, # num_piecewise_capturable_graphs_seen
)


@torch.inference_mode()
@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
def test_simple_inductor_graph_partition(splitting_ops):
assert VLLM_USE_V1
_run_simple_model(
# inductor graph partition automatically resets splitting_ops
# to be an empty list
splitting_ops=splitting_ops,
use_inductor_graph_partition=True,
use_inductor=True,
expected_num_piecewise_graphs_seen=
1, # since not splitting at fx graph level
expected_num_piecewise_capturable_graphs_seen=
1, # since not splitting at fx graph level
expected_num_backend_compilations=
1, # since not splitting at fx graph level
)
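For reference on the expected counters above: with the toy model's two layers and FX-level splitting at each silly.attention call, the split produces 2 * 2 + 1 = 5 piecewise graphs, of which 1 + 2 = 3 are cudagraph-capturable and compiled by the backend; with cudagraph_capture_sizes=[1, 2], that gives 2 * 3 = 6 captured cudagraphs (num_cudagraph_captured=6). With Inductor graph partition enabled, no FX-level split happens, so all three counts collapse to 1.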
1 change: 1 addition & 0 deletions tests/compile/silly_attention.py
@@ -60,4 +60,5 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
tags=(torch._C.Tag.cudagraph_unsafe, ),
)
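The cudagraph_unsafe tag added above is what Inductor's graph partition keys on: ops carrying it are kept outside the captured partitions, taking over the role that FX-level splitting_ops play for Dynamo-based piecewise splitting. Below is a minimal illustrative sketch (not part of this diff) of registering a toy attention-style op with that tag; it assumes vLLM's direct_register_custom_op helper accepts the tags argument exactly as shown above, and the op and library names are hypothetical.

import torch
from torch.library import Library

from vllm.utils import direct_register_custom_op

# Hypothetical fragment library for the toy op.
toy_lib = Library("toy", "FRAGMENT")


def toy_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  out: torch.Tensor) -> None:
    # Stand-in for a real attention kernel; writes into a preallocated buffer.
    out.copy_(q + k + v)


def toy_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                       out: torch.Tensor) -> None:
    return


direct_register_custom_op(
    op_name="toy_attention",
    op_func=toy_attention,
    mutates_args=["out"],
    fake_impl=toy_attention_fake,
    target_lib=toy_lib,
    # The tag tells Inductor's graph partition to split around this op, so
    # cudagraph capture never includes the attention call itself.
    tags=(torch._C.Tag.cudagraph_unsafe, ),
)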
16 changes: 15 additions & 1 deletion tests/compile/test_full_graph.py
@@ -11,8 +11,10 @@

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel, PassConfig
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
PassConfig)
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer

from ..utils import create_new_process_for_each_test

Expand Down Expand Up @@ -107,6 +109,18 @@ def test_full_graph(
(CompilationConfig(level=CompilationLevel.PIECEWISE,
debug_dump_path=tempfile.gettempdir()),
("facebook/opt-125m", {})),
] + [
# graph inductor partition
(
CompilationConfig(
level=CompilationLevel.PIECEWISE,
# inductor graph partition uses
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
compile_sizes=[1, 2]),
model) for model in models_list(all=False)
if is_torch_equal_or_newer("2.9.0.dev")
])
# only test some of the models
@create_new_process_for_each_test()
41 changes: 31 additions & 10 deletions tests/compile/test_fusion_attn.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import logging
from typing import Optional

import pytest
@@ -339,6 +340,10 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
@pytest.mark.parametrize(
"split_attention",
[False, True] if current_platform.is_rocm() else [False])
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
"use_inductor_graph_partition",
[False] if current_platform.is_rocm() else [False, True])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test ROCm or CUDA")
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@@ -352,7 +357,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
dtype: torch.dtype, model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend, split_attention: bool,
monkeypatch, dist_init):
use_inductor_graph_partition: bool,
monkeypatch, dist_init, caplog_vllm):
"""Test AttentionStaticQuantPattern fusion pass"""

monkeypatch.setenv("VLLM_USE_V1", "1")
@@ -372,6 +378,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+quant_fp8"],
use_inductor_graph_partition=use_inductor_graph_partition,
),
cache_config=CacheConfig(cache_dtype="fp8"))

@@ -407,9 +414,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
vllm_config=vllm_config_unfused)
model_unfused = model_unfused.to(device)

forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
batch_size, use_hnd=split_attention)
# TODO(boyuan): the attn_metadata with quantization does not
# work on my server. So skip for inductor graph partition
# test for now.
if not use_inductor_graph_partition:
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
batch_size, use_hnd=split_attention)

# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v)
@@ -429,9 +440,11 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
w=model_unfused.w)
model_fused = model_fused.to(device)

forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
batch_size, use_hnd=split_attention)
# TODO(boyuan): same as above - skip attn_metadata setup under
# inductor graph partition for now.
if not use_inductor_graph_partition:
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
batch_size, use_hnd=split_attention)

# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
@@ -444,16 +457,24 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
backend=test_backend,
fullgraph=True)
assert model_compiled.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v)

with caplog_vllm.at_level(logging.DEBUG):
result_fused_1 = model_compiled(q, k, v)

if backend == _Backend.FLASHINFER:
# With the Flashinfer backend after the 1st round of the forward
# pass, output quant scale should be loaded into the attn layer's
# _o_scale_float, the 2nd round should reuse the loaded
# _o_scale_float
assert model_compiled.attn._o_scale_float is not None
if use_inductor_graph_partition:
assert ("Fused quantization onto 1 attention nodes"
in caplog_vllm.text)
else:
assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v)
assert model_compiled.attn._o_scale_float is not None

if not use_inductor_graph_partition:
assert model_compiled.attn._o_scale_float is not None

torch.testing.assert_close(result_unfused,
result_fused_2,
2 changes: 2 additions & 0 deletions vllm/attention/layer.py
@@ -575,6 +575,7 @@ def unified_attention_fake(
mutates_args=[],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch._C.Tag.cudagraph_unsafe, ),
)


@@ -625,4 +626,5 @@ def unified_attention_with_output_fake(
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch._C.Tag.cudagraph_unsafe, ),
)
43 changes: 41 additions & 2 deletions vllm/compilation/backends.py
@@ -326,6 +326,40 @@ def call_module(self, target: torch.fx.node.Target,
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
global compilation_start_time

if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and self.compilation_config.use_inductor_graph_partition):
# If we're using Inductor-based graph partitioning, we currently
# have the whole `fx.Graph` before Inductor lowering, and
# the piecewise splitting happens after all graph
# passes and fusions. Here, we add a custom hook for Inductor
# to wrap each partition with our static graph wrapper class to
# maintain more control over static graph capture and replay.

from torch._inductor.utils import CUDAGraphWrapperMetadata

from .cuda_graph import CUDAGraphOptions

static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls())

def customized_cudagraph_wrapper(
f, metadata: CUDAGraphWrapperMetadata):
partition_id = metadata.partition_index
num_partitions = metadata.num_partitions
return static_graph_wrapper_class(
runnable=f,
vllm_config=self.vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=partition_id == 0,
gc_disable=partition_id != 0,
weak_ref_output=partition_id == num_partitions - 1,
))

torch._inductor.utils.set_customized_partition_wrappers(
Comment (Collaborator):
What's the lifetime of this value? Do we need to unset it at some point, or is it automatically scoped to this compilation only?

Reply (Contributor, author):
The lifetime is the whole process, i.e. for all compiled functions we apply this wrapper to every graph partition. I'm currently relying on set_forward_context(cudagraph_runtime_mode=CUDAGraphMode.NONE) to turn off cudagraphs when computing input embeddings in _preprocess.
https://github.com/vllm-project/vllm/pull/24281/files#diff-80ee7e2a62f9dcfbb8a312dc4e3948557e97ef187290daebbcae1e28596bda29R1786-R1790

Comment (Collaborator):
Scoping it to the whole process means that testing becomes annoying. Does this need to be set at compile time or at runtime?

Reply (Contributor, author):
@zou3519 has a great suggestion to move this set/unset wrapper to support_torch_compile. Now it is local to the frame!
customized_cudagraph_wrapper)

compiled_graph_for_dynamic_shape = self.vllm_backend.\
compiler_manager.compile(
submod,
@@ -336,15 +370,20 @@ def call_module(self, target: torch.fx.node.Target,
num_graphs=len(self.compile_submod_names),
runtime_shape=None)
# Lazy import here to avoid circular import
from .cuda_graph import CUDAGraphOptions
from .cuda_piecewise_backend import PiecewiseBackend

piecewise_backend = PiecewiseBackend(
submod, self.vllm_config, index,
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_dynamic_shape, self.vllm_backend)

if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and
not self.compilation_config.use_inductor_graph_partition):
# We're using Dynamo-based piecewise splitting, so we wrap
# the whole subgraph with a static graph wrapper.
from .cuda_graph import CUDAGraphOptions

# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
# class) as platform dependent.
static_graph_wrapper_class = resolve_obj_by_qualname(
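As the review thread above notes, later commits scope this hook to a single compiled frame instead of the whole process (see the "move wrapper set/unset to context manager" and "move maybe_use_cudagraph_partition_wrapper to decorators.py" commits). The sketch below only illustrates that idea; it is not the PR's exact code, and it assumes the torch 2.9 torch._inductor.utils.set_customized_partition_wrappers API shown in this diff, including that passing None clears the hook.

import contextlib

import torch

from vllm.config import CUDAGraphMode, VllmConfig


@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
    """Install the Inductor partition wrapper only while compiling one frame."""
    compilation_config = vllm_config.compilation_config
    use_wrapper = (
        compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        and compilation_config.use_inductor_graph_partition)

    if use_wrapper:
        from torch._inductor.utils import CUDAGraphWrapperMetadata

        def wrap_partition(f, metadata: CUDAGraphWrapperMetadata):
            # The real hook builds the platform's static graph wrapper here,
            # mirroring customized_cudagraph_wrapper in backends.py above.
            return f

        torch._inductor.utils.set_customized_partition_wrappers(wrap_partition)

    try:
        yield
    finally:
        if use_wrapper:
            # Clear the hook so it does not leak into unrelated compilations
            # (assumes None is accepted as "no wrapper").
            torch._inductor.utils.set_customized_partition_wrappers(None)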