Skip to content

Commit 31bca0d

Browse files
authored
Merge branch 'main' into zhxchen17/precompile/2
2 parents 7935089 + cddce79 commit 31bca0d

File tree

9 files changed

+267
-112
lines changed

9 files changed

+267
-112
lines changed

tests/compile/piecewise/test_multiple_graphs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
198198
compilation_config=CompilationConfig(
199199
level=CompilationLevel.PIECEWISE,
200200
use_cudagraph=True,
201-
splitting_ops=["silly.attention"],
201+
splitting_ops=["silly::attention"],
202202
cudagraph_capture_sizes=[1, 2],
203203
)
204204
)
@@ -267,7 +267,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
267267
compilation_config=CompilationConfig(
268268
level=CompilationLevel.PIECEWISE,
269269
use_cudagraph=False,
270-
splitting_ops=["silly.attention"],
270+
splitting_ops=["silly::attention"],
271271
)
272272
)
273273
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE

tests/compile/piecewise/test_simple.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def _run_simple_model(
127127
@torch.inference_mode()
128128
def test_simple_piecewise_compile(use_inductor):
129129
_run_simple_model(
130-
splitting_ops=["silly.attention"],
130+
splitting_ops=["silly::attention"],
131131
use_inductor_graph_partition=False,
132132
use_inductor=use_inductor,
133133
# 2 * num_layers + 1
@@ -142,7 +142,7 @@ def test_simple_piecewise_compile(use_inductor):
142142

143143

144144
@torch.inference_mode()
145-
@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
145+
@pytest.mark.parametrize("splitting_ops", [["silly::attention"], []])
146146
def test_simple_inductor_graph_partition(splitting_ops, monkeypatch):
147147
if not is_torch_equal_or_newer("2.9.0.dev"):
148148
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

tests/compile/piecewise/test_toy_llama.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def run_model(
268268
cudagraph_capture_sizes=[1, 2],
269269
)
270270
if split_attn:
271-
compilation_config.splitting_ops = ["silly.attention"]
271+
compilation_config.splitting_ops = ["silly::attention"]
272272
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
273273
else:
274274
compilation_config = CompilationConfig(
@@ -438,7 +438,7 @@ def benchmark():
438438
compilation_config = CompilationConfig(
439439
level=CompilationLevel.PIECEWISE,
440440
use_cudagraph=True,
441-
splitting_ops=["silly.attention"],
441+
splitting_ops=["silly::attention"],
442442
cudagraph_capture_sizes=cudagraph_sizes,
443443
)
444444
else:

tests/compile/test_config.py

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
from vllm.compilation.counter import compilation_counter
66
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
7-
from vllm.utils import _is_torch_equal_or_newer
7+
from vllm.config.compilation import CompilationLevel
8+
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
89

910

1011
def test_version():
12+
# Test the version comparison logic using the private function
1113
assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
1214
assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
1315
assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
@@ -17,6 +19,9 @@ def test_version():
1719

1820
def test_use_cudagraphs_dynamic():
1921
vllm_config = VllmConfig()
22+
# Default V1 configuration now starts without cudagraphs enabled; the
23+
# engine decides when to capture based on runtime settings instead of a
24+
# blanket default.
2025
assert vllm_config.compilation_config.use_cudagraph
2126

2227

@@ -137,58 +142,77 @@ def test_enforce_eager(vllm_runner, monkeypatch):
137142
def test_splitting_ops_dynamic():
138143
# Default config
139144
config = VllmConfig()
140-
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
141-
assert config.compilation_config.splitting_ops_contain_attention()
145+
# Default V1 config leaves cudagraph mode unset; splitting ops are only
146+
# populated when the engine decides to use piecewise compilation.
147+
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
148+
assert not config.compilation_config.splitting_ops_contain_attention()
142149

143150
# When use_inductor_graph_partition=True
144-
if _is_torch_equal_or_newer("2.9.0.dev"):
145-
# inductor graph partition is only available in PyTorch 2.9+.
146-
# this is a fast config check so we are not using pytest.skip.
151+
if is_torch_equal_or_newer("2.9.0.dev"):
147152
config = VllmConfig(
148153
compilation_config=CompilationConfig(
149-
use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
154+
level=CompilationLevel.PIECEWISE,
155+
use_inductor_graph_partition=True,
156+
splitting_ops=["vllm::unified_attention"],
150157
)
151158
)
152-
# should ignore splitting_ops
153-
assert config.compilation_config.splitting_ops == []
159+
# with inductor partition we use splitting_ops directly for
160+
# partition rules
161+
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
154162

155-
# When attn_fusion pass enabled.
163+
# When attn_fusion pass enabled, splitting_ops now default to attention ops.
156164
config = VllmConfig(
157165
compilation_config=CompilationConfig(
166+
level=CompilationLevel.PIECEWISE,
158167
pass_config={"enable_attn_fusion": True, "enable_noop": True},
159168
custom_ops=["+quant_fp8"],
160169
cudagraph_mode=CUDAGraphMode.PIECEWISE,
161170
)
162171
)
163-
assert config.compilation_config.splitting_ops == []
164-
# cudagraph mode also fall back to FULL
165-
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
166-
167-
# splitting_ops can not contain attention ops when attn_fusion
168-
# pass enabled.
169-
with pytest.raises(AssertionError):
170-
config = VllmConfig(
171-
compilation_config=CompilationConfig(
172-
pass_config={"enable_attn_fusion": True, "enable_noop": True},
173-
custom_ops=["+quant_fp8"],
174-
cudagraph_mode=CUDAGraphMode.PIECEWISE,
175-
# work around for accessing all attntion ops
176-
splitting_ops=CompilationConfig()._attention_ops,
177-
)
178-
)
172+
# With the new simplified logic, attention fusion works with splitting_ops
173+
assert config.compilation_config.splitting_ops_contain_attention()
174+
# cudagraph mode remains PIECEWISE
175+
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
179176

180177
# When both use_inductor_graph_partition and attn_fusion pass enabled.
181-
if _is_torch_equal_or_newer("2.9.0.dev"):
178+
if is_torch_equal_or_newer("2.9.0.dev"):
182179
config = VllmConfig(
183180
compilation_config=CompilationConfig(
181+
level=CompilationLevel.PIECEWISE,
184182
use_inductor_graph_partition=True,
185183
pass_config={"enable_attn_fusion": True, "enable_noop": True},
186184
custom_ops=["+quant_fp8"],
187185
cudagraph_mode=CUDAGraphMode.PIECEWISE,
188186
)
189187
)
190-
assert config.compilation_config.splitting_ops == []
191-
# enable_attn_fusion is directly support under
188+
# With inductor graph partition, attn_fusion and splitting_ops
189+
# work together. Default splitting_ops include attention ops.
190+
assert config.compilation_config.splitting_ops_contain_attention()
191+
# enable_attn_fusion is directly supported under
192192
# use_inductor_graph_partition=True, and cudagraph_mode
193193
# is unchanged.
194194
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
195+
196+
197+
def test_resolve_operator_overload():
198+
import torch
199+
200+
from vllm.compilation.partition_rules import resolve_defined_ops
201+
202+
# Test valid operator names
203+
resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
204+
assert len(resolved) == 2
205+
assert resolved[0] is torch.ops.aten.mm.default
206+
assert resolved[1] is torch.ops.aten.addmm.default
207+
208+
# Test that invalid operators are skipped (not raising exceptions)
209+
resolved = resolve_defined_ops(
210+
[
211+
"aten::mm.default",
212+
"aten::nonexistent_op.default", # This should be skipped
213+
"aten::addmm.default",
214+
]
215+
)
216+
assert len(resolved) == 2 # Only 2 valid ops
217+
assert resolved[0] is torch.ops.aten.mm.default
218+
assert resolved[1] is torch.ops.aten.addmm.default

tests/compile/test_decorator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_ignore_torch_compile_decorator():
7171
compilation_config=CompilationConfig(
7272
level=CompilationLevel.PIECEWISE,
7373
use_cudagraph=True,
74-
splitting_ops=["silly.attention"],
74+
splitting_ops=["silly::attention"],
7575
cudagraph_capture_sizes=[1, 2],
7676
)
7777
)
@@ -186,7 +186,7 @@ def test_conditional_compile_enable_if():
186186
compilation_config=CompilationConfig(
187187
level=CompilationLevel.PIECEWISE,
188188
use_cudagraph=True,
189-
splitting_ops=["silly.attention"],
189+
splitting_ops=["silly::attention"],
190190
cudagraph_capture_sizes=[1, 2],
191191
),
192192
)
@@ -218,7 +218,7 @@ def test_conditional_compile_enable_if():
218218
compilation_config=CompilationConfig(
219219
level=CompilationLevel.PIECEWISE,
220220
use_cudagraph=True,
221-
splitting_ops=["silly.attention"],
221+
splitting_ops=["silly::attention"],
222222
cudagraph_capture_sizes=[1, 2],
223223
),
224224
)

vllm/compilation/backends.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
from torch._dispatch.python import enable_python_dispatcher
1717

1818
import vllm.envs as envs
19+
from vllm.compilation.inductor_pass import pass_context
20+
from vllm.compilation.partition_rules import (
21+
inductor_partition_rule_context,
22+
resolve_defined_ops,
23+
)
1924
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
2025
from vllm.logger import init_logger
2126
from vllm.platforms import current_platform
@@ -78,6 +83,21 @@ def __init__(self, compilation_config: CompilationConfig):
7883
def compute_hash(self, vllm_config: VllmConfig) -> str:
7984
return self.compiler.compute_hash(vllm_config)
8085

86+
@contextmanager
87+
def compile_context(self, runtime_shape: Optional[int] = None):
88+
"""Provide compilation context for the duration of compilation to set
89+
any torch global properties we want to scope to a single Inductor
90+
compilation (e.g. partition rules, pass context)."""
91+
with pass_context(runtime_shape):
92+
if self.compilation_config.use_inductor_graph_partition:
93+
inductor_partition_ops = resolve_defined_ops(
94+
self.compilation_config.splitting_ops
95+
)
96+
with inductor_partition_rule_context(inductor_partition_ops):
97+
yield
98+
else:
99+
yield
100+
81101
def initialize_cache(
82102
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
83103
):
@@ -200,9 +220,15 @@ def compile(
200220
maybe_key = None
201221
else:
202222
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
203-
compiled_graph, handle = self.compiler.compile(
204-
graph, example_inputs, additional_inductor_config, runtime_shape, maybe_key
205-
)
223+
224+
with self.compile_context(runtime_shape):
225+
compiled_graph, handle = self.compiler.compile(
226+
graph,
227+
example_inputs,
228+
additional_inductor_config,
229+
runtime_shape,
230+
maybe_key,
231+
)
206232

207233
assert compiled_graph is not None, "Failed to compile the graph"
208234

@@ -261,7 +287,7 @@ class SplitItem:
261287

262288

263289
def split_graph(
264-
graph: fx.GraphModule, ops: list[str]
290+
graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
265291
) -> tuple[fx.GraphModule, list[SplitItem]]:
266292
# split graph by ops
267293
subgraph_id = 0
@@ -270,7 +296,12 @@ def split_graph(
270296
for node in graph.graph.nodes:
271297
if node.op in ("output", "placeholder"):
272298
continue
273-
if node.op == "call_function" and str(node.target) in ops:
299+
# Match node.target against resolved_ops
300+
# node.target can be OpOverloadPacket, need to check .default
301+
if node.op == "call_function" and (
302+
node.target in resolved_ops
303+
or (hasattr(node.target, "default") and node.target.default in resolved_ops)
304+
):
274305
subgraph_id += 1
275306
node_to_subgraph_id[node] = subgraph_id
276307
split_op_graphs.append(subgraph_id)
@@ -594,9 +625,14 @@ def __call__(
594625
self.graph = graph
595626
self.configure_post_pass()
596627

597-
self.split_gm, self.piecewise_graphs = split_graph(
598-
graph, self.compilation_config.splitting_ops
599-
)
628+
if self.compilation_config.use_inductor_graph_partition:
629+
# Let Inductor decide partitioning; avoid FX-level pre-splitting.
630+
fx_split_ops: list[str] = []
631+
else:
632+
fx_split_ops = self.compilation_config.splitting_ops or []
633+
634+
resolved_split_ops = resolve_defined_ops(fx_split_ops)
635+
self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)
600636

601637
from torch._dynamo.utils import lazy_format_graph_code
602638

vllm/compilation/compiler_interface.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
from vllm.config import VllmConfig
1818
from vllm.utils import is_torch_equal_or_newer
1919

20-
from .inductor_pass import pass_context
21-
2220

2321
class CompilerInterface:
2422
"""
@@ -210,13 +208,12 @@ def compile(
210208

211209
from torch._inductor import standalone_compile
212210

213-
with pass_context(runtime_shape):
214-
compiled_graph = standalone_compile(
215-
graph,
216-
example_inputs,
217-
dynamic_shapes=dynamic_shapes,
218-
options={"config_patches": current_config},
219-
)
211+
compiled_graph = standalone_compile(
212+
graph,
213+
example_inputs,
214+
dynamic_shapes=dynamic_shapes,
215+
options={"config_patches": current_config},
216+
)
220217

221218
# Save the compiled artifact to disk in the specified path
222219
assert key is not None
@@ -464,13 +461,12 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
464461
torch._functorch.config.patch(enable_remote_autograd_cache=False)
465462
)
466463

467-
with pass_context(runtime_shape):
468-
compiled_graph = compile_fx(
469-
graph,
470-
example_inputs,
471-
inner_compile=hijacked_compile_fx_inner,
472-
config_patches=current_config,
473-
)
464+
compiled_graph = compile_fx(
465+
graph,
466+
example_inputs,
467+
inner_compile=hijacked_compile_fx_inner,
468+
config_patches=current_config,
469+
)
474470

475471
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
476472
# compilation cache. So turn off the checks if we disable the

0 commit comments

Comments
 (0)