AOT Compilation for torch.compile (Bundled) #24274
Conversation
Code Review
This pull request introduces Ahead-Of-Time (AOT) compilation for torch.compile, aiming to improve performance by caching compiled artifacts. The changes include new environment variables to control AOT compilation, a custom serializable compiled function class, and updates to the compilation decorators and wrappers. My review has identified a critical issue with cache isolation that could lead to corruption, a high-severity issue with the cache key hashing strategy that could cause stale cache hits, and a high-severity maintainability concern due to heavy reliance on internal PyTorch APIs. Addressing these points will improve the robustness and long-term stability of this new feature.
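To illustrate the cache-key concern: a robust key must fold in every input that can change the compiled artifact, or stale hits become possible. A generic sketch, not the vLLM implementation (`cache_key` and its inputs are hypothetical):

```python
import hashlib
import json
import tempfile

def cache_key(config: dict, source_files: list) -> str:
    # Hash every input that influences the compiled artifact; leaving
    # any of them out (config fields, traced source files) is exactly
    # what makes stale cache hits possible.
    h = hashlib.sha256()
    h.update(json.dumps(config, sort_keys=True).encode())
    for path in sorted(source_files):
        with open(path, "rb") as f:
            h.update(f.read())
    return h.hexdigest()

with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as src:
    src.write(b"def forward(x): return x\n")

key_a = cache_key({"tp": 2}, [src.name])
key_b = cache_key({"tp": 4}, [src.name])  # config change -> different key
```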
vllm/compilation/backends.py
Outdated
```python
class VllmCompiledFunction(SerializableCallable):

    def __init__(self, graph_module, example_inputs, vllm_config, prefix,
                 optimized_call):
        assert isinstance(graph_module, torch.fx.GraphModule)
        self.graph_module = graph_module
        self.example_inputs = example_inputs
        self.vllm_config = vllm_config
        self.prefix = prefix
        self.optimized_call = optimized_call

    def __call__(self, *args, **kwargs):
        return self.optimized_call(*args, **kwargs)

    @classmethod
    def serialize_compile_artifacts(
            cls, compiled_fn: "VllmCompiledFunction") -> bytes:
        import sympy
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler, Options

        state = compiled_fn.__dict__.copy()
        state.pop("optimized_call")
        for node in state["graph_module"].graph.nodes:
            node.meta.pop("source_fn_stack", None)
            node.meta.pop("nn_module_stack", None)

        graph_reducer_override = GraphPickler.reducer_override

        def _graph_reducer_override(self, obj):
            if (inspect.isclass(obj) and issubclass(obj, sympy.Function)
                    and hasattr(obj, "_torch_unpickler")):
                return obj._torch_unpickler, (obj._torch_handler_name, )
            if isinstance(obj, FakeTensorMode):
                return type(None), ()
            return graph_reducer_override(self, obj)

        with patch.object(GraphPickler, 'reducer_override',
                          _graph_reducer_override):
            state["graph_module"] = GraphPickler.dumps(
                state["graph_module"], Options(ops_filter=None))
            state["example_inputs"] = GraphPickler.dumps(
                state["example_inputs"])
        return pickle.dumps(state)

    @classmethod
    def deserialize_compile_artifacts(cls,
                                      data: bytes) -> "VllmCompiledFunction":
        from torch._guards import TracingContext, tracing
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        state = pickle.loads(data)
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        state["graph_module"] = GraphPickler.loads(state["graph_module"],
                                                   fake_mode)
        state["example_inputs"] = GraphPickler.loads(state["example_inputs"],
                                                     fake_mode)
        vllm_backend = VllmBackend(state["vllm_config"], state["prefix"])
        with tracing(TracingContext(fake_mode)):
            optimized_call = vllm_backend(state["graph_module"],
                                          state["example_inputs"])

        return cls(
            state["graph_module"],
            state["example_inputs"],
            state["vllm_config"],
            state["prefix"],
            optimized_call,
        )
```
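The `reducer_override` patch above swaps unpicklable objects (like `FakeTensorMode`) for placeholders during serialization. The same pattern can be shown with only the standard library; `SkippingPickler` is an illustrative stand-in, not a vLLM or torch API:

```python
import io
import os
import pickle

class SkippingPickler(pickle.Pickler):
    """Replace objects that cannot be pickled (here: open file handles)
    with None, mirroring the FakeTensorMode -> None trick above."""

    def reducer_override(self, obj):
        if isinstance(obj, io.IOBase):
            return type(None), ()   # reconstructs as None on load
        return NotImplemented       # fall back to default pickling

fh = open(os.devnull, "rb")
buf = io.BytesIO()
SkippingPickler(buf).dump({"x": 1, "fh": fh})
fh.close()

restored = pickle.loads(buf.getvalue())
print(restored)  # {'x': 1, 'fh': None}
```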
The implementation of VllmCompiledFunction relies heavily on internal PyTorch APIs (e.g., torch._dynamo.aot_compile, torch.fx._graph_pickler, torch._subclasses). While this might be necessary for this feature, it creates a significant maintenance burden. These APIs are not guaranteed to be stable and can change without notice in future PyTorch releases, which could break this functionality. It would be good to add comments explaining why each internal API is used and potentially explore ways to reduce this dependency if possible in the future.
I'll be happy to add some comments explaining what these APIs are. IMO the function names should be self-evident, and I consider these parts to be relatively stable in torch.
@zhxchen17
ilmarkov left a comment
Thank you for your work! Generally the PR looks good to me. I left some minor comments.
vllm/compilation/decorators.py
Outdated
```python
aot_compilation_path = os.path.join(cache_dir, "model")
try:
    with open(aot_compilation_path, "rb") as f:
        aot_compiled_fn = torch.compiler.load_compiled_function(f)
```
Is there internal verification of the CUDA version and hardware when torch.compile loads the artifact?
For now we haven't implemented it yet, since this feature is relatively new (which is why we start with opt-in). Here is a list of guards we plan to implement for loading:
- torch version
- python version
- cuda version
- hardware
- traced source files
Let me know if there's anything I'm missing, thanks.
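Such guards could be checked at load time roughly as follows; `check_load_guards` and its guard names are hypothetical, not the actual torch implementation:

```python
def check_load_guards(saved: dict, current: dict) -> list:
    """Return names of guards that do not match; an empty list means
    the AOT artifact is safe to load. Guard names are illustrative."""
    guards = ("torch_version", "python_version", "cuda_version",
              "hardware", "traced_source_hash")
    return [g for g in guards if saved.get(g) != current.get(g)]

saved = {"torch_version": "2.10.0", "python_version": "3.12",
         "cuda_version": "12.9", "hardware": "B200",
         "traced_source_hash": "abc123"}
current = dict(saved, cuda_version="12.4")
print(check_load_guards(saved, current))  # ['cuda_version']
```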
Hardware check landed on torch side pytorch/pytorch#162438
@vadiklyutiy Yes, that's part of our goal. Overall we are changing the usage of torch.compile from JIT mode to AOT mode, and the major benefit will be reduced warm start time for torch.compile() (since we skip dynamo as well on the second run). One side effect of the work should be a clearer and more stable boundary between the torch compiler and vLLM's custom backend.
zou3519 left a comment
Looks good to me overall. I need a bit of time to think about how this fits into the CompilerManager and CompilerInterface abstraction. My initial gut reaction is that this is something completely separate.
Okay I think I'm fine on the structure (this PR adds something that is different from CompileManager/CompilerInterface), so mostly minor comments. We can always figure out the right abstraction for this (if it needs abstraction) later.
Mostly some questions/comments about code reuse, the cache directory structure, and what exactly aot_compile returns
Benchmark result (regarding cold start vs warm start):

Test environment: Nvidia 8xB200 node, pytorch 2.10 main branch + cuda 12.9

nvidia/Llama-3.3-70B-Instruct-FP8 (TP=2)
Cold start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 62.86 s in total

Qwen/Qwen3-32B
Cold start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 37.62 s in total

deepseek-ai/DeepSeek-V3.1 (TP=8)
Cold start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 73.92 s in total

openai/gpt-oss-120b (TP=2)
Cold start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 33.51 s in total

zai-org/GLM-4.5-Air (TP=2)
Cold start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 41.60 s in total
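The `[monitor.py:34]` lines above come from a timing hook around torch.compile; a generic version of such a monitor might look like this (hypothetical, not the actual monitor.py):

```python
import time
from contextlib import contextmanager

@contextmanager
def monitor(label: str):
    # Time a block of work and report it in the same style as the
    # benchmark output above.
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        print(f"[monitor] {label} takes {elapsed:.2f} s in total")

with monitor("torch.compile"):
    sum(i * i for i in range(100_000))  # stand-in for compilation work
```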
Updates:
rebased
ProExpertProg left a comment
LGTM overall!
tests/compile/test_aot_compile.py
Outdated
```python
m.setenv("VLLM_USE_AOT_COMPILE", "1")
torch._dynamo.reset()
with use_vllm_config(vllm_config), torch.compiler.set_stance(
```
why duplicate use_vllm_config?
nice catch. I think it's by accident and I will remove this.
```python
torch._dynamo.reset()
with use_vllm_config(vllm_config), torch.compiler.set_stance(
        "fail_on_recompile"):
    actual = CompiledMod(vllm_config=vllm_config)(*args)
```
Why doesn't this fail - where does the compiled code come from? Does the previous run that raised a recompile error create it? Or does it come from the cache?
I think the name of the API torch.compiler.set_stance('fail_on_recompile') is the source of confusion here. Basically torch.compile() has two modes now: JIT and AOT. torch.compiler.set_stance('fail_on_recompile') means torch.compile() will fail when we recompile in JIT mode.
Here, by setting VLLM_USE_AOT_COMPILE=1, we're testing that torch.compile() JIT mode is not triggered. We're not testing the loading behavior in this unit test yet (we'll test the loading part in the following tests). In other words, we are just testing that we're using the correct AOT compile API from torch.
I think it's possible to address this by naming our API something like set_stance("fail_on_new_cache_entry") or better, but the behavior here is just about JIT vs AOT.
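The JIT-vs-AOT distinction can be mimicked with a plain-Python cache; `JitCompiler` is a toy stand-in (not a torch API) where disallowing new entries plays the role of fail_on_recompile:

```python
class JitCompiler:
    """Toy JIT cache: compiles on first miss; with new entries
    disallowed, a miss raises instead (like fail_on_recompile)."""

    def __init__(self):
        self.cache = {}
        self.allow_new_entries = True

    def __call__(self, fn, shape):
        key = (fn.__name__, shape)
        if key not in self.cache:
            if not self.allow_new_entries:
                raise RuntimeError("recompile attempted under fail_on_recompile")
            self.cache[key] = fn  # "compile" and cache the result
        return self.cache[key]

compiler = JitCompiler()
compiler(abs, (4,))                  # first call: JIT-compiles
compiler.allow_new_entries = False   # stance: fail_on_recompile
compiler(abs, (4,))                  # cache hit: fine; a new shape would raise
```

In the same spirit, if the AOT path works correctly, no new JIT cache entry is ever requested, so the stance never fires.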
```python
class VllmSerializableFunction(SerializableCallable):
    """
```
Can we add some more comments on what this does, not just why it's needed? Devs will navigate here from the use and should be informed that this is mostly a wrapper around graph_module so they can just skip through it if they're not interested in the serialization
```python
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"])

def optimized_call(*example_inputs):
```
The control flow here seems a bit complex, could we add a comment or two?
Added comments and rebased.
Signed-off-by: zhxchen17 <[email protected]>
df40fe6 to
aba7a85
Compare
Signed-off-by: zhxchen17 <[email protected]> Signed-off-by: Dhruvil Bhatt <[email protected]>
```python
def use_aot_compile() -> bool:
    from vllm.utils import is_torch_equal_or_newer

    default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
```
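The default-on-for-new-torch pattern in that line can be sketched end to end; `use_aot_compile_sketch` and `_version_at_least` are simplified stand-ins for the vLLM helpers:

```python
import os

def _version_at_least(version: str, minimum: str) -> bool:
    # Compare dotted numeric prefixes, ignoring suffixes like ".dev".
    def parts(v):
        out = []
        for p in v.split("."):
            digits = "".join(ch for ch in p if ch.isdigit())
            if not digits:
                break
            out.append(int(digits))
        return out
    return parts(version) >= parts(minimum)

def use_aot_compile_sketch(torch_version: str) -> bool:
    # Opt in by default on torch >= 2.10, but let the environment
    # variable override the default either way.
    default_value = "1" if _version_at_least(torch_version, "2.10.0") else "0"
    return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
```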
@zou3519 From this line, it makes sense that the logic here runs on PyTorch CI when we test against the PyTorch main branch. So, there are a couple of failures there https://github.com/pytorch/pytorch/actions/runs/18522236183/job/52791051622 blocking the vLLM commit pin update.
cc @zhxchen17 these tests are failing on PyTorch main, can you take a look please?
pytorch/pytorch#165702 should partially fix the test failures on main branch. I will do a full test on test_basic_correctness and report back.
If you need to test your fix on PyTorch side, you could bump the pinned vLLM commit we have there https://github.com/pytorch/pytorch/blob/main/.github/ci_commit_pins/vllm.txt to a recent one on your PR, then add ciflow/vllm to run the tests on vLLM x PyTorch main
All fixes have landed for now. Trying to update the pin with pytorch/pytorch#166494.
Signed-off-by: zhxchen17 <[email protected]> Signed-off-by: bbartels <[email protected]>
Signed-off-by: zhxchen17 <[email protected]> Signed-off-by: xuebwang-amd <[email protected]>
Signed-off-by: zhxchen17 <[email protected]> Signed-off-by: 0xrushi <[email protected]>
This PR consists of two parts/commits, to make it easier to review the overall change being added for aot precompile.
Overall the change should be orthogonal to other changes we are doing for dynamo, aot dispatcher and inductor.
To not interfere with the existing workflow, we create a new cache directory `torch_aot_compile` and store aot compiled artifacts there. In the future `torch_aot_compile` should contain all the compiled artifacts, but right now we require both `torch_aot_compile` and `torch_compile_cache` to be present to avoid recompilation. We plan to gradually migrate the contents into `torch_aot_compile` in the long term.
Purpose
Add an AOT compilation workflow for torch.compile without changing the existing caching behavior.
Mechanically how it works: we hook in at the `supports_torch_compile` decorator layer, so that it intercepts calls to the model's forward function directly. Essentially, 2. and 3. should be the warm start paths, and 4. is the cold start path.
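The warm/cold start split described above can be sketched as a decorator that checks the cache directory first; everything here is illustrative (a toy pickled artifact stands in for real compiled code, and the names are not the actual vLLM APIs):

```python
import os
import pickle
import tempfile

def supports_aot_compile(cache_dir):
    # Toy version of the decorator flow: load a saved artifact if one
    # exists (warm start), otherwise "compile" and save it (cold start).
    def decorator(fn):
        path = os.path.join(cache_dir, "model")

        def wrapper(*args):
            if os.path.exists(path):
                with open(path, "rb") as f:       # warm start: load artifact
                    artifact = pickle.load(f)
            else:
                artifact = {"compiled_for": fn.__name__}  # cold start: save
                with open(path, "wb") as f:
                    pickle.dump(artifact, f)
            return fn(*args)

        return wrapper
    return decorator

cache_dir = tempfile.mkdtemp()

@supports_aot_compile(cache_dir)
def forward(x):
    return x * 2

forward(3)  # cold start: writes cache_dir/model
forward(3)  # warm start: loads the artifact instead of compiling
```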
Test Plan
tests/test_aot_compile.py
Test Result