Skip to content

Commit aba7a85

Browse files
committed
AOT compilation workflow [2/n]
Signed-off-by: zhxchen17 <[email protected]>
1 parent 208af72 commit aba7a85

File tree

8 files changed

+423
-114
lines changed

8 files changed

+423
-114
lines changed

tests/compile/test_aot_compile.py

Lines changed: 106 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,139 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import tempfile
45
from contextlib import contextmanager
56

67
import pytest
78
import torch
89

910
from vllm.compilation.decorators import support_torch_compile
10-
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
11-
set_current_vllm_config)
11+
from vllm.config import (
12+
CompilationConfig,
13+
CompilationLevel,
14+
VllmConfig,
15+
set_current_vllm_config,
16+
)
1217
from vllm.forward_context import set_forward_context
18+
from vllm.utils import is_torch_equal_or_newer
1319

1420

15-
class MyMod(torch.nn.Module):
21+
def reference_fn(x: torch.Tensor):
22+
assert x.shape[0] <= 42
23+
assert x.shape[0] % 2 == 0
24+
for _ in range(3000):
25+
x = x + x.shape[0]
26+
return x
1627

28+
29+
@support_torch_compile
30+
class CompiledMod(torch.nn.Module):
1731
def __init__(self, **kwargs):
1832
super().__init__()
1933

2034
def forward(self, x: torch.Tensor):
21-
for _ in range(3000):
22-
x = x + x.shape[0]
23-
return x
35+
return reference_fn(x)
2436

2537

2638
def make_vllm_config() -> VllmConfig:
27-
return VllmConfig(compilation_config=CompilationConfig(
28-
level=CompilationLevel.PIECEWISE, ))
39+
return VllmConfig(
40+
compilation_config=CompilationConfig(
41+
level=CompilationLevel.PIECEWISE,
42+
)
43+
)
2944

3045

3146
@contextmanager
3247
def use_vllm_config(vllm_config: VllmConfig):
33-
with set_forward_context(
34-
{}, vllm_config), set_current_vllm_config(vllm_config):
48+
with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
3549
yield
3650

3751

38-
def test_no_eval_frame(monkeypatch: pytest.MonkeyPatch):
52+
@pytest.mark.skipif(
53+
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
54+
)
55+
def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
3956
with monkeypatch.context() as m:
40-
mod = MyMod()
41-
args = (torch.randn(10, 10), )
42-
expected = mod(*args)
43-
CompiledMod = support_torch_compile(MyMod)
44-
4557
vllm_config = make_vllm_config()
46-
m.setenv("VLLM_USE_AOT_COMPILE", "0")
47-
try:
48-
with use_vllm_config(vllm_config), torch.compiler.set_stance(
49-
"fail_on_recompile"):
58+
args = (torch.randn(10, 10),)
59+
expected = reference_fn(*args)
60+
with use_vllm_config(vllm_config):
61+
m.setenv("VLLM_USE_AOT_COMPILE", "0")
62+
with (
63+
pytest.raises(RuntimeError, match="Detected recompile"),
64+
torch.compiler.set_stance("fail_on_recompile"),
65+
):
5066
CompiledMod(vllm_config=vllm_config)(*args)
51-
except RuntimeError as e:
52-
assert "Detected recompile" in str(e)
53-
else:
54-
raise AssertionError("Expected exception to be raised")
5567

68+
m.setenv("VLLM_USE_AOT_COMPILE", "1")
69+
torch._dynamo.reset()
70+
with torch.compiler.set_stance("fail_on_recompile"):
71+
actual = CompiledMod(vllm_config=vllm_config)(*args)
72+
assert torch.allclose(actual, expected)
73+
74+
75+
@pytest.mark.skipif(
76+
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
77+
)
78+
def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
79+
with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
80+
args = (torch.randn(10, 10),)
5681
m.setenv("VLLM_USE_AOT_COMPILE", "1")
57-
torch._dynamo.reset()
58-
with use_vllm_config(vllm_config), torch.compiler.set_stance(
59-
"fail_on_recompile"):
60-
ret = CompiledMod(vllm_config=vllm_config)(*args)
82+
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
83+
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
84+
vllm_config = make_vllm_config()
85+
with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError):
86+
CompiledMod(vllm_config=vllm_config)(*args)
87+
88+
89+
@pytest.mark.skipif(
90+
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
91+
)
92+
def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
93+
with monkeypatch.context() as m:
94+
args = (torch.randn(10, 10),)
95+
96+
with tempfile.TemporaryDirectory() as tmpdirname:
97+
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
98+
m.setenv("VLLM_USE_AOT_COMPILE", "1")
99+
vllm_config = make_vllm_config()
100+
with use_vllm_config(vllm_config):
101+
expected = CompiledMod(vllm_config=vllm_config)(*args)
102+
103+
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
104+
vllm_config = make_vllm_config()
105+
with use_vllm_config(vllm_config):
106+
ret = CompiledMod(vllm_config=vllm_config)(*args)
61107
assert torch.allclose(ret, expected)
108+
109+
110+
@pytest.mark.skipif(
111+
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
112+
)
113+
def test_shape_env(monkeypatch: pytest.MonkeyPatch):
114+
"""
115+
Test that the shape environment is correctly serialized and preserved
116+
when loading from cache.
117+
"""
118+
with monkeypatch.context() as m:
119+
args = (torch.randn(10, 10),)
120+
121+
with tempfile.TemporaryDirectory() as tmpdirname:
122+
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
123+
m.setenv("VLLM_USE_AOT_COMPILE", "1")
124+
vllm_config = make_vllm_config()
125+
with use_vllm_config(vllm_config):
126+
compiled_mod = CompiledMod(vllm_config=vllm_config)
127+
compiled_mod(*args)
128+
artifacts = compiled_mod.aot_compiled_fn._artifacts
129+
guards_string = artifacts.compiled_fn.shape_env.format_guards()
130+
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
131+
132+
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
133+
vllm_config = make_vllm_config()
134+
with use_vllm_config(vllm_config):
135+
compiled_mod = CompiledMod(vllm_config=vllm_config)
136+
compiled_mod(*args)
137+
artifacts = compiled_mod.aot_compiled_fn._artifacts
138+
guards_string = artifacts.compiled_fn.shape_env.format_guards()
139+
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"

tools/pre_commit/check_pickle_imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"vllm/multimodal/hasher.py",
2323
"vllm/transformers_utils/config.py",
2424
"vllm/model_executor/models/registry.py",
25+
"vllm/compilation/caching.py",
2526
"tests/utils_/test_utils.py",
2627
"tests/tokenization/test_cached_tokenizer.py",
2728
"vllm/distributed/utils.py",

vllm/compilation/backends.py

Lines changed: 16 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import ast
55
import dataclasses
6+
import hashlib
67
import os
78
import pprint
89
import time
@@ -25,6 +26,7 @@
2526
from vllm.platforms import current_platform
2627
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
2728

29+
from .caching import VllmSerializableFunction
2830
from .compiler_interface import (
2931
CompilerInterface,
3032
EagerAdaptor,
@@ -195,6 +197,7 @@ def compile(
195197
# there can be multiple graphs due to piecewise compilation.
196198
now = time.time()
197199
elapsed = now - compilation_start_time
200+
compilation_config.compilation_time += elapsed
198201
if runtime_shape is None:
199202
logger.info(
200203
"Directly load the compiled graph(s) for dynamic shape "
@@ -472,35 +475,6 @@ def set_model_tag(tag: str):
472475
model_tag = old_tag
473476

474477

475-
try:
476-
from torch._dynamo.aot_compile import SerializableCallable
477-
except ImportError:
478-
SerializableCallable = object
479-
480-
assert isinstance(SerializableCallable, type)
481-
482-
483-
class VllmCompiledFunction(SerializableCallable):
484-
485-
def __init__(self, graph_module, example_inputs, vllm_config,
486-
optimized_call):
487-
self.graph_module = graph_module
488-
self.example_inputs = example_inputs
489-
self.vllm_config = vllm_config
490-
self.optimized_call = optimized_call
491-
492-
def __call__(self, *args, **kwargs):
493-
return self.optimized_call(*args, **kwargs)
494-
495-
@classmethod
496-
def serialize_compile_artifacts(cls, compiled_fn):
497-
raise NotImplementedError("serialization not implemented")
498-
499-
@classmethod
500-
def deserialize_compile_artifacts(cls, data):
501-
raise NotImplementedError("deserialization not implemented")
502-
503-
504478
class VllmBackend:
505479
"""The compilation backend for `torch.compile` with vLLM.
506480
It is used for compilation level of `CompilationLevel.PIECEWISE`,
@@ -578,47 +552,23 @@ def configure_post_pass(self):
578552
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
579553
inductor_config[PASS_KEY] = self.post_grad_pass_manager
580554

581-
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
555+
def __call__(
556+
self, graph: fx.GraphModule, example_inputs
557+
) -> VllmSerializableFunction:
558+
from .caching import _compute_code_hash, compilation_config_hash_factors
559+
582560
vllm_config = self.vllm_config
583561
if not self.compilation_config.cache_dir:
584562
# no provided cache dir, generate one based on the known factors
585563
# that affects the compilation. if none of the factors change,
586564
# the cache dir will be the same so that we can reuse the compiled
587565
# graph.
588566

589-
factors = []
590-
# 0. factors come from the env, for example, The values of
591-
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
592-
env_hash = envs.compute_hash()
593-
factors.append(env_hash)
594-
595-
# 1. factors come from the vllm_config (it mainly summarizes how the
596-
# model is created)
597-
config_hash = vllm_config.compute_hash()
598-
factors.append(config_hash)
599-
567+
factors = compilation_config_hash_factors(vllm_config)
600568
# 2. factors come from the code files that are traced by Dynamo (
601569
# it mainly summarizes how the model is used in forward pass)
602-
forward_code_files = list(sorted(self.compilation_config.traced_files))
570+
code_hash = _compute_code_hash(self.compilation_config.traced_files)
603571
self.compilation_config.traced_files.clear()
604-
logger.debug(
605-
"Traced files (to be considered for compilation cache):\n%s",
606-
"\n".join(forward_code_files),
607-
)
608-
hash_content = []
609-
for filepath in forward_code_files:
610-
hash_content.append(filepath)
611-
if filepath == "<string>":
612-
# This means the function was dynamically generated, with
613-
# e.g. exec(). We can't actually check these.
614-
continue
615-
with open(filepath) as f:
616-
hash_content.append(f.read())
617-
import hashlib
618-
619-
code_hash = hashlib.md5(
620-
"\n".join(hash_content).encode(), usedforsecurity=False
621-
).hexdigest()
622572
factors.append(code_hash)
623573

624574
# 3. compiler hash
@@ -724,8 +674,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
724674
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
725675
or not self.compilation_config.cudagraph_copy_inputs
726676
):
727-
return VllmCompiledFunction(graph, example_inputs, vllm_config,
728-
self.split_gm)
677+
return VllmSerializableFunction(
678+
graph, example_inputs, self.prefix, self.split_gm
679+
)
729680

730681
# if we need to copy input buffers for cudagraph
731682
from torch._guards import detect_fake_mode
@@ -770,5 +721,6 @@ def copy_and_call(*args):
770721
list_args[index] = static_tensor
771722
return self.split_gm(*list_args)
772723

773-
return VllmCompiledFunction(graph, example_inputs, vllm_config,
774-
copy_and_call)
724+
return VllmSerializableFunction(
725+
graph, example_inputs, self.prefix, copy_and_call
726+
)

0 commit comments

Comments
 (0)