Skip to content

Commit 6d27095

Browse files
committed
AOT compilation workflow [2/n]
Signed-off-by: zhxchen17 <[email protected]>
1 parent 6fc2967 commit 6d27095

File tree

8 files changed

+368
-92
lines changed

8 files changed

+368
-92
lines changed

tests/compile/test_aot_compile.py

Lines changed: 82 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import tempfile
45
from contextlib import contextmanager
56

67
import pytest
@@ -12,15 +13,22 @@
1213
from vllm.forward_context import set_forward_context
1314

1415

15-
class MyMod(torch.nn.Module):
16+
def reference_fn(x: torch.Tensor):
17+
assert x.shape[0] <= 42
18+
assert x.shape[0] % 2 == 0
19+
for _ in range(3000):
20+
x = x + x.shape[0]
21+
return x
22+
23+
24+
@support_torch_compile
25+
class CompiledMod(torch.nn.Module):
1626

1727
def __init__(self, **kwargs):
1828
super().__init__()
1929

2030
def forward(self, x: torch.Tensor):
21-
for _ in range(3000):
22-
x = x + x.shape[0]
23-
return x
31+
return reference_fn(x)
2432

2533

2634
def make_vllm_config() -> VllmConfig:
@@ -30,32 +38,84 @@ def make_vllm_config() -> VllmConfig:
3038

3139
@contextmanager
3240
def use_vllm_config(vllm_config: VllmConfig):
33-
with set_forward_context(
34-
{}, vllm_config), set_current_vllm_config(vllm_config):
41+
with set_forward_context({}, vllm_config), \
42+
set_current_vllm_config(vllm_config):
3543
yield
3644

3745

38-
def test_no_eval_frame(monkeypatch: pytest.MonkeyPatch):
46+
def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
3947
with monkeypatch.context() as m:
40-
mod = MyMod()
48+
vllm_config = make_vllm_config()
4149
args = (torch.randn(10, 10), )
42-
expected = mod(*args)
43-
CompiledMod = support_torch_compile(MyMod)
50+
expected = reference_fn(*args)
51+
with use_vllm_config(vllm_config):
52+
m.setenv("VLLM_USE_AOT_COMPILE", "0")
53+
with pytest.raises(RuntimeError, match="Detected recompile"), \
54+
torch.compiler.set_stance("fail_on_recompile"):
55+
CompiledMod(vllm_config=vllm_config)(*args)
4456

45-
vllm_config = make_vllm_config()
46-
m.setenv("VLLM_USE_AOT_COMPILE", "0")
47-
try:
57+
m.setenv("VLLM_USE_AOT_COMPILE", "1")
58+
torch._dynamo.reset()
4859
with use_vllm_config(vllm_config), torch.compiler.set_stance(
4960
"fail_on_recompile"):
50-
CompiledMod(vllm_config=vllm_config)(*args)
51-
except RuntimeError as e:
52-
assert "Detected recompile" in str(e)
53-
else:
54-
raise AssertionError("Expected exception to be raised")
61+
actual = CompiledMod(vllm_config=vllm_config)(*args)
62+
assert torch.allclose(actual, expected)
5563

64+
65+
def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
66+
with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context(
67+
) as m:
68+
args = (torch.randn(10, 10), )
5669
m.setenv("VLLM_USE_AOT_COMPILE", "1")
57-
torch._dynamo.reset()
58-
with use_vllm_config(vllm_config), torch.compiler.set_stance(
59-
"fail_on_recompile"):
60-
ret = CompiledMod(vllm_config=vllm_config)(*args)
70+
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
71+
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
72+
vllm_config = make_vllm_config()
73+
with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError):
74+
CompiledMod(vllm_config=vllm_config)(*args)
75+
76+
77+
def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
78+
with monkeypatch.context() as m:
79+
args = (torch.randn(10, 10), )
80+
81+
with tempfile.TemporaryDirectory() as tmpdirname:
82+
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
83+
m.setenv("VLLM_USE_AOT_COMPILE", "1")
84+
vllm_config = make_vllm_config()
85+
with use_vllm_config(vllm_config):
86+
expected = CompiledMod(vllm_config=vllm_config)(*args)
87+
88+
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
89+
vllm_config = make_vllm_config()
90+
with use_vllm_config(vllm_config):
91+
ret = CompiledMod(vllm_config=vllm_config)(*args)
6192
assert torch.allclose(ret, expected)
93+
94+
95+
def test_shape_env(monkeypatch: pytest.MonkeyPatch):
96+
"""
97+
Test that the shape environment is correctly serialized and preserved
98+
when loading from cache.
99+
"""
100+
with monkeypatch.context() as m:
101+
args = (torch.randn(10, 10), )
102+
103+
with tempfile.TemporaryDirectory() as tmpdirname:
104+
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
105+
m.setenv("VLLM_USE_AOT_COMPILE", "1")
106+
vllm_config = make_vllm_config()
107+
with use_vllm_config(vllm_config):
108+
compiled_mod = CompiledMod(vllm_config=vllm_config)
109+
compiled_mod(*args)
110+
artifacts = compiled_mod.aot_compiled_fn._artifacts
111+
guards_string = artifacts.compiled_fn.shape_env.format_guards()
112+
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
113+
114+
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
115+
vllm_config = make_vllm_config()
116+
with use_vllm_config(vllm_config):
117+
compiled_mod = CompiledMod(vllm_config=vllm_config)
118+
compiled_mod(*args)
119+
artifacts = compiled_mod.aot_compiled_fn._artifacts
120+
guards_string = artifacts.compiled_fn.shape_env.format_guards()
121+
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"

tools/pre_commit/check_pickle_imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
'vllm/multimodal/hasher.py',
2323
'vllm/transformers_utils/config.py',
2424
'vllm/model_executor/models/registry.py',
25+
"vllm/compilation/caching.py",
2526
'tests/utils_/test_utils.py',
2627
'tests/tokenization/test_cached_tokenizer.py',
2728
'vllm/distributed/utils.py',

vllm/compilation/backends.py

Lines changed: 16 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import ast
55
import dataclasses
6+
import hashlib
67
import os
78
import pprint
89
import time
@@ -20,6 +21,7 @@
2021
from vllm.platforms import current_platform
2122
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
2223

24+
from .caching import VllmSerializableFunction
2325
from .compiler_interface import (CompilerInterface, EagerAdaptor,
2426
InductorAdaptor, InductorStandaloneAdaptor)
2527
from .counter import compilation_counter
@@ -160,6 +162,7 @@ def compile(self,
160162
# there can be multiple graphs due to piecewise compilation.
161163
now = time.time()
162164
elapsed = now - compilation_start_time
165+
compilation_config.compilation_time += elapsed
163166
if runtime_shape is None:
164167
logger.info(
165168
"Directly load the compiled graph(s) for dynamic shape "
@@ -398,35 +401,6 @@ def set_model_tag(tag: str):
398401
model_tag = old_tag
399402

400403

401-
try:
402-
from torch._dynamo.aot_compile import SerializableCallable
403-
except ImportError:
404-
SerializableCallable = object
405-
406-
assert isinstance(SerializableCallable, type)
407-
408-
409-
class VllmCompiledFunction(SerializableCallable):
410-
411-
def __init__(self, graph_module, example_inputs, vllm_config,
412-
optimized_call):
413-
self.graph_module = graph_module
414-
self.example_inputs = example_inputs
415-
self.vllm_config = vllm_config
416-
self.optimized_call = optimized_call
417-
418-
def __call__(self, *args, **kwargs):
419-
return self.optimized_call(*args, **kwargs)
420-
421-
@classmethod
422-
def serialize_compile_artifacts(cls, compiled_fn):
423-
raise NotImplementedError("serialization not implemented")
424-
425-
@classmethod
426-
def deserialize_compile_artifacts(cls, data):
427-
raise NotImplementedError("deserialization not implemented")
428-
429-
430404
class VllmBackend:
431405
"""The compilation backend for `torch.compile` with vLLM.
432406
It is used for compilation level of `CompilationLevel.PIECEWISE`,
@@ -502,7 +476,11 @@ def configure_post_pass(self):
502476
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
503477
inductor_config[PASS_KEY] = self.post_grad_pass_manager
504478

505-
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
479+
def __call__(self, graph: fx.GraphModule,
480+
example_inputs) -> VllmSerializableFunction:
481+
482+
from .caching import (_compute_code_hash,
483+
compilation_config_hash_factors)
506484

507485
vllm_config = self.vllm_config
508486
if not self.compilation_config.cache_dir:
@@ -511,37 +489,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
511489
# the cache dir will be the same so that we can reuse the compiled
512490
# graph.
513491

514-
factors = []
515-
# 0. factors come from the env, for example, The values of
516-
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
517-
env_hash = envs.compute_hash()
518-
factors.append(env_hash)
519-
520-
# 1. factors come from the vllm_config (it mainly summarizes how the
521-
# model is created)
522-
config_hash = vllm_config.compute_hash()
523-
factors.append(config_hash)
524-
492+
factors = compilation_config_hash_factors(vllm_config)
525493
# 2. factors come from the code files that are traced by Dynamo (
526494
# it mainly summarizes how the model is used in forward pass)
527-
forward_code_files = list(
528-
sorted(self.compilation_config.traced_files))
495+
code_hash = _compute_code_hash(
496+
self.compilation_config.traced_files)
529497
self.compilation_config.traced_files.clear()
530-
logger.debug(
531-
"Traced files (to be considered for compilation cache):\n%s",
532-
"\n".join(forward_code_files))
533-
hash_content = []
534-
for filepath in forward_code_files:
535-
hash_content.append(filepath)
536-
if filepath == "<string>":
537-
# This means the function was dynamically generated, with
538-
# e.g. exec(). We can't actually check these.
539-
continue
540-
with open(filepath) as f:
541-
hash_content.append(f.read())
542-
import hashlib
543-
code_hash = hashlib.md5("\n".join(hash_content).encode(),
544-
usedforsecurity=False).hexdigest()
498+
545499
factors.append(code_hash)
546500

547501
# 3. compiler hash
@@ -634,8 +588,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
634588

635589
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
636590
not self.compilation_config.cudagraph_copy_inputs:
637-
return VllmCompiledFunction(graph, example_inputs, vllm_config,
638-
self.split_gm)
591+
return VllmSerializableFunction(graph, example_inputs, self.prefix,
592+
self.split_gm)
639593

640594
# if we need to copy input buffers for cudagraph
641595
from torch._guards import detect_fake_mode
@@ -677,5 +631,5 @@ def copy_and_call(*args):
677631
list_args[index] = static_tensor
678632
return self.split_gm(*list_args)
679633

680-
return VllmCompiledFunction(graph, example_inputs, vllm_config,
681-
copy_and_call)
634+
return VllmSerializableFunction(graph, example_inputs, self.prefix,
635+
copy_and_call)

0 commit comments

Comments
 (0)