Skip to content

Commit d5540b1

Browse files
zhxchen17xuebwang-amd
authored andcommitted
AOT Compilation for torch.compile (Bundled) (vllm-project#24274)
Signed-off-by: zhxchen17 <[email protected]> Signed-off-by: xuebwang-amd <[email protected]>
1 parent 9f8c1db commit d5540b1

File tree

9 files changed

+484
-41
lines changed

9 files changed

+484
-41
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ steps:
403403
- pytest -v -s compile/test_fusion_all_reduce.py
404404
- pytest -v -s compile/test_decorator.py
405405
- pytest -v -s compile/test_noop_elimination.py
406+
- pytest -v -s compile/test_aot_compile.py
406407

407408
- label: PyTorch Fullgraph Smoke Test # 15min
408409
timeout_in_minutes: 30

tests/compile/test_aot_compile.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import tempfile
5+
from contextlib import contextmanager
6+
7+
import pytest
8+
import torch
9+
10+
from vllm.compilation.decorators import support_torch_compile
11+
from vllm.config import (
12+
CompilationConfig,
13+
CompilationLevel,
14+
VllmConfig,
15+
set_current_vllm_config,
16+
)
17+
from vllm.forward_context import set_forward_context
18+
from vllm.utils import is_torch_equal_or_newer
19+
20+
21+
def reference_fn(x: torch.Tensor):
22+
assert x.shape[0] <= 42
23+
assert x.shape[0] % 2 == 0
24+
for _ in range(3000):
25+
x = x + x.shape[0]
26+
return x
27+
28+
29+
@support_torch_compile
30+
class CompiledMod(torch.nn.Module):
31+
def __init__(self, **kwargs):
32+
super().__init__()
33+
34+
def forward(self, x: torch.Tensor):
35+
return reference_fn(x)
36+
37+
38+
def make_vllm_config() -> VllmConfig:
39+
return VllmConfig(
40+
compilation_config=CompilationConfig(
41+
level=CompilationLevel.PIECEWISE,
42+
)
43+
)
44+
45+
46+
@contextmanager
47+
def use_vllm_config(vllm_config: VllmConfig):
48+
with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
49+
yield
50+
51+
52+
@pytest.mark.skipif(
53+
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
54+
)
55+
def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
56+
with monkeypatch.context() as m:
57+
vllm_config = make_vllm_config()
58+
args = (torch.randn(10, 10),)
59+
expected = reference_fn(*args)
60+
with use_vllm_config(vllm_config):
61+
m.setenv("VLLM_USE_AOT_COMPILE", "0")
62+
with (
63+
pytest.raises(RuntimeError, match="Detected recompile"),
64+
torch.compiler.set_stance("fail_on_recompile"),
65+
):
66+
CompiledMod(vllm_config=vllm_config)(*args)
67+
68+
m.setenv("VLLM_USE_AOT_COMPILE", "1")
69+
torch._dynamo.reset()
70+
with torch.compiler.set_stance("fail_on_recompile"):
71+
actual = CompiledMod(vllm_config=vllm_config)(*args)
72+
assert torch.allclose(actual, expected)
73+
74+
75+
@pytest.mark.skipif(
76+
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
77+
)
78+
def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
79+
with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
80+
args = (torch.randn(10, 10),)
81+
m.setenv("VLLM_USE_AOT_COMPILE", "1")
82+
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
83+
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
84+
vllm_config = make_vllm_config()
85+
with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError):
86+
CompiledMod(vllm_config=vllm_config)(*args)
87+
88+
89+
@pytest.mark.skipif(
90+
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
91+
)
92+
def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
93+
with monkeypatch.context() as m:
94+
args = (torch.randn(10, 10),)
95+
96+
with tempfile.TemporaryDirectory() as tmpdirname:
97+
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
98+
m.setenv("VLLM_USE_AOT_COMPILE", "1")
99+
vllm_config = make_vllm_config()
100+
with use_vllm_config(vllm_config):
101+
expected = CompiledMod(vllm_config=vllm_config)(*args)
102+
103+
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
104+
vllm_config = make_vllm_config()
105+
with use_vllm_config(vllm_config):
106+
ret = CompiledMod(vllm_config=vllm_config)(*args)
107+
assert torch.allclose(ret, expected)
108+
109+
110+
@pytest.mark.skipif(
111+
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
112+
)
113+
def test_shape_env(monkeypatch: pytest.MonkeyPatch):
114+
"""
115+
Test that the shape environment is correctly serialized and preserved
116+
when loading from cache.
117+
"""
118+
with monkeypatch.context() as m:
119+
args = (torch.randn(10, 10),)
120+
121+
with tempfile.TemporaryDirectory() as tmpdirname:
122+
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
123+
m.setenv("VLLM_USE_AOT_COMPILE", "1")
124+
vllm_config = make_vllm_config()
125+
with use_vllm_config(vllm_config):
126+
compiled_mod = CompiledMod(vllm_config=vllm_config)
127+
compiled_mod(*args)
128+
artifacts = compiled_mod.aot_compiled_fn._artifacts
129+
guards_string = artifacts.compiled_fn.shape_env.format_guards()
130+
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
131+
132+
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
133+
vllm_config = make_vllm_config()
134+
with use_vllm_config(vllm_config):
135+
compiled_mod = CompiledMod(vllm_config=vllm_config)
136+
compiled_mod(*args)
137+
artifacts = compiled_mod.aot_compiled_fn._artifacts
138+
guards_string = artifacts.compiled_fn.shape_env.format_guards()
139+
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"

tools/pre_commit/check_pickle_imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"vllm/multimodal/hasher.py",
2323
"vllm/transformers_utils/config.py",
2424
"vllm/model_executor/models/registry.py",
25+
"vllm/compilation/caching.py",
2526
"tests/utils_/test_utils.py",
2627
"tests/tokenization/test_cached_tokenizer.py",
2728
"vllm/distributed/utils.py",

vllm/compilation/backends.py

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import ast
55
import dataclasses
6+
import hashlib
67
import os
78
import pprint
89
import time
@@ -25,6 +26,7 @@
2526
from vllm.platforms import current_platform
2627
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
2728

29+
from .caching import VllmSerializableFunction
2830
from .compiler_interface import (
2931
CompilerInterface,
3032
EagerAdaptor,
@@ -195,6 +197,7 @@ def compile(
195197
# there can be multiple graphs due to piecewise compilation.
196198
now = time.time()
197199
elapsed = now - compilation_start_time
200+
compilation_config.compilation_time += elapsed
198201
if runtime_shape is None:
199202
logger.info(
200203
"Directly load the compiled graph(s) for dynamic shape "
@@ -549,47 +552,23 @@ def configure_post_pass(self):
549552
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
550553
inductor_config[PASS_KEY] = self.post_grad_pass_manager
551554

552-
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
555+
def __call__(
556+
self, graph: fx.GraphModule, example_inputs
557+
) -> VllmSerializableFunction:
558+
from .caching import _compute_code_hash, compilation_config_hash_factors
559+
553560
vllm_config = self.vllm_config
554561
if not self.compilation_config.cache_dir:
555562
# no provided cache dir, generate one based on the known factors
556563
# that affects the compilation. if none of the factors change,
557564
# the cache dir will be the same so that we can reuse the compiled
558565
# graph.
559566

560-
factors = []
561-
# 0. factors come from the env, for example, The values of
562-
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
563-
env_hash = envs.compute_hash()
564-
factors.append(env_hash)
565-
566-
# 1. factors come from the vllm_config (it mainly summarizes how the
567-
# model is created)
568-
config_hash = vllm_config.compute_hash()
569-
factors.append(config_hash)
570-
567+
factors = compilation_config_hash_factors(vllm_config)
571568
# 2. factors come from the code files that are traced by Dynamo (
572569
# it mainly summarizes how the model is used in forward pass)
573-
forward_code_files = list(sorted(self.compilation_config.traced_files))
570+
code_hash = _compute_code_hash(self.compilation_config.traced_files)
574571
self.compilation_config.traced_files.clear()
575-
logger.debug(
576-
"Traced files (to be considered for compilation cache):\n%s",
577-
"\n".join(forward_code_files),
578-
)
579-
hash_content = []
580-
for filepath in forward_code_files:
581-
hash_content.append(filepath)
582-
if filepath == "<string>":
583-
# This means the function was dynamically generated, with
584-
# e.g. exec(). We can't actually check these.
585-
continue
586-
with open(filepath) as f:
587-
hash_content.append(f.read())
588-
import hashlib
589-
590-
code_hash = hashlib.md5(
591-
"\n".join(hash_content).encode(), usedforsecurity=False
592-
).hexdigest()
593572
factors.append(code_hash)
594573

595574
# 3. compiler hash
@@ -695,7 +674,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
695674
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
696675
or not self.compilation_config.cudagraph_copy_inputs
697676
):
698-
return self.split_gm
677+
return VllmSerializableFunction(
678+
graph, example_inputs, self.prefix, self.split_gm
679+
)
699680

700681
# if we need to copy input buffers for cudagraph
701682
from torch._guards import detect_fake_mode
@@ -740,4 +721,6 @@ def copy_and_call(*args):
740721
list_args[index] = static_tensor
741722
return self.split_gm(*list_args)
742723

743-
return copy_and_call
724+
return VllmSerializableFunction(
725+
graph, example_inputs, self.prefix, copy_and_call
726+
)

0 commit comments

Comments
 (0)