
Conversation

Contributor

@zhxchen17 zhxchen17 commented Sep 4, 2025

This PR consists of two parts/commits, to make the overall AOT precompile change easier to review.

  1. The first commit adds a new option, VLLM_USE_AOT_COMPILE, which makes the torch.compile wrapper always use torch.compile().aot_compile(inputs) on the first run and reuse the AOT-compiled function in subsequent runs.
  2. The second commit provides a basic implementation of vLLM's custom compiler backend plugged into AOT compilation serialization. We are not introducing new compilation behavior here: we simply store the dynamo graph and example inputs, and on load we rerun the vLLM backend, expecting the compilation to already be cached. In the future we plan to increase the serialization coverage so that we can always store backend artifacts as part of the package.

Overall the change should be orthogonal to the other changes we are making for dynamo, the AOT dispatcher, and inductor, because

  • the dynamo <-> vllm backend surface is stable
  • the aot dispatcher <-> inductor surface is internal to this PR, i.e. it can be treated as a black box that does not affect the work in this PR.

To avoid interfering with the existing workflow, we create a new cache directory, torch_aot_compile, and store AOT-compiled artifacts there. Eventually torch_aot_compile should contain all compiled artifacts, but right now both torch_aot_compile and torch_compile_cache must be present to avoid recompilation. We plan to gradually migrate the contents into torch_aot_compile in the long term.

Purpose

Add AOT compilation workflow for torch.compile without changing the existing caching behavior.

Mechanically how it works:

  1. We hook this into the supports_torch_compile decorator layer, so that it intercepts calls to the model's forward function directly.
  2. Check whether an AOT-compiled function is already in memory; if so, use it.
  3. If there's no AOT-compiled function in memory, compute a cache key from the vLLM config plus the model's forward name and try to load an AOT-compiled function into memory. (This is a different key from the one in torch_compile_cache, which has access to the traced source files; those are not present in the AOT workflow.)
  4. If no AOT-compiled function is available at all, just kick off AOT compilation and save the result to disk for future use.

Essentially, steps 2 and 3 are the warm start paths, and step 4 is the cold start path.
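As a rough illustration of the lookup order above (hypothetical helper names and hashing, not vLLM's actual code; the real implementation lives behind the supports_torch_compile decorator and torch.compiler's save/load APIs, with pickle here only as a stand-in):

```python
import hashlib
import os
import pickle
import tempfile

_in_memory_cache = {}  # hypothetical in-process cache of loaded artifacts

def aot_cache_key(vllm_config_repr: str, forward_name: str) -> str:
    # Step 3: the key is derived from the vLLM config plus the forward's
    # name; unlike torch_compile_cache, no traced source files are
    # available in the AOT workflow.
    return hashlib.sha256(
        f"{vllm_config_repr}:{forward_name}".encode()).hexdigest()

def get_compiled_forward(vllm_config_repr, forward_name, compile_fn,
                         cache_root):
    key = aot_cache_key(vllm_config_repr, forward_name)
    if key in _in_memory_cache:        # step 2: warm start, in memory
        return _in_memory_cache[key]
    path = os.path.join(cache_root, key, "model")
    if os.path.exists(path):           # step 3: warm start, on disk
        with open(path, "rb") as f:
            artifact = pickle.load(f)  # stand-in for load_compiled_function
    else:                              # step 4: cold start
        artifact = compile_fn()        # stand-in for aot_compile(inputs)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(artifact, f)   # persist for future runs
    _in_memory_cache[key] = artifact
    return artifact

# Demo: the first call compiles; later calls hit the caches.
root = tempfile.mkdtemp()
calls = []
make = lambda: calls.append(1) or "fake-compiled-forward"
a = get_compiled_forward("cfg", "forward", make, root)
_in_memory_cache.clear()               # simulate a fresh process
b = get_compiled_forward("cfg", "forward", make, root)
print(a, b, len(calls))  # fake-compiled-forward fake-compiled-forward 1
```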

Test Plan

tests/test_aot_compile.py

Test Result

==================================================================================== test session starts =====================================================================================
platform linux -- Python 3.12.11, pytest-7.3.2, pluggy-1.6.0
rootdir: /data/users/zhxchen17/vllm
configfile: pyproject.toml
plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, subtests-0.13.1, rerunfailures-14.0, flakefinder-1.1.0, cpp-2.3.0, anyio-4.9.0
collected 3 items                                                                                                                                                                            

tests/compile/test_aot_compile.py ...                                                                                                                                                  [100%]

====================================================================================== warnings summary ======================================================================================
tests/compile/test_aot_compile.py: 12 warnings
  /data/users/zhxchen17/pytorch/torch/fx/_graph_pickler.py:124: DeprecationWarning: Pickle, copy, and deepcopy support will be removed from itertools in Python 3.14.
    pickler.dump(obj)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================================== 3 passed, 12 warnings in 54.35s ===============================================================================


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces Ahead-Of-Time (AOT) compilation for torch.compile, aiming to improve performance by caching compiled artifacts. The changes include new environment variables to control AOT compilation, a custom serializable compiled function class, and updates to the compilation decorators and wrappers. My review has identified a critical issue with cache isolation that could lead to corruption, a high-severity issue with the cache key hashing strategy that could cause stale cache hits, and a high-severity maintainability concern due to heavy reliance on internal PyTorch APIs. Addressing these points will improve the robustness and long-term stability of this new feature.

Comment on lines 403 to 472
```python
class VllmCompiledFunction(SerializableCallable):

    def __init__(self, graph_module, example_inputs, vllm_config, prefix,
                 optimized_call):
        assert isinstance(graph_module, torch.fx.GraphModule)
        self.graph_module = graph_module
        self.example_inputs = example_inputs
        self.vllm_config = vllm_config
        self.prefix = prefix
        self.optimized_call = optimized_call

    def __call__(self, *args, **kwargs):
        return self.optimized_call(*args, **kwargs)

    @classmethod
    def serialize_compile_artifacts(
            cls, compiled_fn: "VllmCompiledFunction") -> bytes:
        import sympy
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler, Options
        state = compiled_fn.__dict__.copy()
        state.pop("optimized_call")
        for node in state["graph_module"].graph.nodes:
            node.meta.pop("source_fn_stack", None)
            node.meta.pop("nn_module_stack", None)

        graph_reducer_override = GraphPickler.reducer_override

        def _graph_reducer_override(self, obj):
            if (inspect.isclass(obj) and issubclass(obj, sympy.Function)
                    and hasattr(obj, "_torch_unpickler")):
                return obj._torch_unpickler, (obj._torch_handler_name, )
            if isinstance(obj, FakeTensorMode):
                return type(None), ()
            return graph_reducer_override(self, obj)

        with patch.object(GraphPickler, 'reducer_override',
                          _graph_reducer_override):
            state["graph_module"] = GraphPickler.dumps(
                state["graph_module"], Options(ops_filter=None))
            state["example_inputs"] = GraphPickler.dumps(
                state["example_inputs"])
        return pickle.dumps(state)

    @classmethod
    def deserialize_compile_artifacts(cls,
                                      data: bytes) -> "VllmCompiledFunction":
        from torch._guards import TracingContext, tracing
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        state = pickle.loads(data)
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        state["graph_module"] = GraphPickler.loads(state["graph_module"],
                                                   fake_mode)
        state["example_inputs"] = GraphPickler.loads(state["example_inputs"],
                                                     fake_mode)
        vllm_backend = VllmBackend(state["vllm_config"], state["prefix"])
        with tracing(TracingContext(fake_mode)):
            optimized_call = vllm_backend(state["graph_module"],
                                          state["example_inputs"])

        return cls(
            state["graph_module"],
            state["example_inputs"],
            state["vllm_config"],
            state["prefix"],
            optimized_call,
        )
```
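The reducer_override hook above substitutes pickle-friendly replacements for objects that otherwise cannot be serialized (e.g. a FakeTensorMode is reduced to None on load). The same stdlib mechanism, in a minimal self-contained sketch with a made-up stand-in class:

```python
import io
import pickle

class UnpicklableSession:
    """Stands in for an object like FakeTensorMode that cannot be pickled."""
    def __reduce__(self):
        raise TypeError("not picklable")

class SkippingPickler(pickle.Pickler):
    # reducer_override (Python >= 3.8) lets a pickler substitute its own
    # reduce tuple for selected objects before normal dispatch runs.
    def reducer_override(self, obj):
        if isinstance(obj, UnpicklableSession):
            # Same idea as mapping FakeTensorMode -> None in the PR:
            # on load, the object is reconstructed as None.
            return type(None), ()
        return NotImplemented  # fall back to normal pickling

state = {"graph": [1, 2, 3], "session": UnpicklableSession()}
buf = io.BytesIO()
SkippingPickler(buf).dump(state)
restored = pickle.loads(buf.getvalue())
print(restored)  # {'graph': [1, 2, 3], 'session': None}
```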
Contributor

Severity: high

The implementation of VllmCompiledFunction relies heavily on internal PyTorch APIs (e.g., torch._dynamo.aot_compile, torch.fx._graph_pickler, torch._subclasses). While this might be necessary for this feature, it creates a significant maintenance burden. These APIs are not guaranteed to be stable and can change without notice in future PyTorch releases, which could break this functionality. It would be good to add comments explaining why each internal API is used and potentially explore ways to reduce this dependency if possible in the future.

Contributor Author

Will be happy to add some comments to show what these APIs are. IMO the function names should be self-evident, and I do consider these parts to be relatively stable in torch.

@zhxchen17 zhxchen17 force-pushed the zhxchen17/precompile/2 branch 2 times, most recently from a4662c8 to 2124e79 Compare September 4, 2025 20:06
@vadiklyutiy
Collaborator

@zhxchen17
Do I understand correctly that the goal is to avoid/deprecate the use of internal (and possibly unstable) torch interfaces that we currently rely on for compilation?

Contributor

@ilmarkov ilmarkov left a comment


Thank you for your work! Generally the PR looks good to me. I left some minor comments.

```python
aot_compilation_path = os.path.join(cache_dir, "model")
try:
    with open(aot_compilation_path, "rb") as f:
        aot_compiled_fn = torch.compiler.load_compiled_function(f)
```
Contributor

Is there internal verification of the CUDA version and hardware when torch.compile loads?

Contributor Author

@zhxchen17 zhxchen17 Sep 5, 2025


We haven't implemented it yet, since this feature is relatively new (which is why it starts as opt-in). Here is the list of guards we plan to implement for loading:

  • torch version
  • python version
  • cuda version
  • hardware
  • traced source files

Let me know if there's anything I'm missing, thanks.
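A guard check of that shape could look roughly like the following stdlib-only sketch (all names here are hypothetical; in a real implementation the torch version, CUDA version, and device name would be read from torch and recorded alongside the artifact):

```python
import platform

def collect_env_guards() -> dict:
    """Hypothetical guard snapshot saved next to an AOT artifact.

    A real implementation would also record torch.__version__,
    torch.version.cuda, and torch.cuda.get_device_name()."""
    return {
        "python": platform.python_version(),
        "machine": platform.machine(),
    }

def guards_match(saved: dict) -> bool:
    # Refuse to load the artifact if any recorded value has drifted.
    current = collect_env_guards()
    return all(current.get(k) == v for k, v in saved.items())

saved = collect_env_guards()             # written at aot_compile time
print(guards_match(saved))               # True in the same environment
print(guards_match({"python": "0.0"}))   # False: stale artifact
```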

Contributor Author

Hardware check landed on torch side pytorch/pytorch#162438

Contributor Author

@zhxchen17 Do I understand correctly that the goal is to avoid/deprecate the use of internal (and possibly unstable) torch interfaces that we currently rely on for compilation?

@vadiklyutiy Yes, that's part of our goal. Overall we are changing the usage of torch.compile from JIT mode to AOT mode, and the major benefit is reduced warm start time for torch.compile() (since we also skip dynamo on the second run). A side effect of this work should be a clearer and more stable boundary between the torch compiler and vLLM's custom backend.

@zhxchen17 zhxchen17 force-pushed the zhxchen17/precompile/2 branch 5 times, most recently from c35b8ae to c396bb7 Compare September 10, 2025 20:13
Collaborator

@zou3519 zou3519 left a comment


Looks good to me overall. I need a bit of time to think about how this fits into the CompilerManager and CompilerInterface abstractions. My initial gut reaction is that this is something completely separate.

Collaborator

@zou3519 zou3519 left a comment


Okay, I think I'm fine with the structure (this PR adds something different from CompilerManager/CompilerInterface), so mostly minor comments. We can always figure out the right abstraction for this (if it needs an abstraction) later.

Mostly some questions/comments about code reuse, the cache directory structure, and what exactly aot_compile returns

@zhxchen17 zhxchen17 requested a review from zou3519 September 11, 2025 17:06
@zhxchen17 zhxchen17 force-pushed the zhxchen17/precompile/2 branch from c396bb7 to beccd65 Compare September 11, 2025 21:22
@zhxchen17 zhxchen17 force-pushed the zhxchen17/precompile/2 branch from beccd65 to 73b971e Compare September 15, 2025 19:47
Contributor Author

Benchmark result (regarding cold start vs warm start):

Test environment: NVIDIA 8xB200 node, PyTorch 2.10 main branch + CUDA 12.9
Test script: https://gist.github.com/zhxchen17/75ad6c2576794607ee2cd2ff6e421b9e

nvidia/Llama-3.3-70B-Instruct-FP8 (TP=2)

Cold start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 62.86 s in total
Cold start (VLLM_USE_AOT_COMPILE=0): [monitor.py:32] torch.compile takes 69.35 s in total
Warm start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 10.68 s in total
Warm start (VLLM_USE_AOT_COMPILE=0): [monitor.py:34] torch.compile takes 17.61 s in total

Qwen/Qwen3-32B

Cold start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 37.62 s in total
Cold start (VLLM_USE_AOT_COMPILE=0): [monitor.py:32] torch.compile takes 37.64 s in total
Warm start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 8.09 s in total
Warm start (VLLM_USE_AOT_COMPILE=0): [monitor.py:34] torch.compile takes 10.86 s in total

deepseek-ai/DeepSeek-V3.1 (TP=8)

Cold start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 73.92 s in total
Cold start (VLLM_USE_AOT_COMPILE=0): [monitor.py:32] torch.compile takes 67.44 s in total
Warm start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 7.38 s in total
Warm start (VLLM_USE_AOT_COMPILE=0): [monitor.py:34] torch.compile takes 11.54 s in total

openai/gpt-oss-120b (TP=2)

Cold start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 33.51 s in total
Cold start (VLLM_USE_AOT_COMPILE=0): [monitor.py:32] torch.compile takes 37.27 s in total
Warm start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 3.53 s in total
Warm start (VLLM_USE_AOT_COMPILE=0): [monitor.py:34] torch.compile takes 5.73 s in total

zai-org/GLM-4.5-Air (TP=2)

Cold start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 41.60 s in total
Cold start (VLLM_USE_AOT_COMPILE=0): [monitor.py:32] torch.compile takes 45.74 s in total
Warm start (VLLM_USE_AOT_COMPILE=1): [monitor.py:34] torch.compile takes 4.69 s in total
Warm start (VLLM_USE_AOT_COMPILE=0): [monitor.py:34] torch.compile takes 7.49 s in total

Contributor Author

Updates:

  • Enabled VLLM_USE_AOT_COMPILE=1 for torch>=2.10
  • Added benchmarks comparing cold start and warm start.
  • Rebased on main

@zhxchen17 zhxchen17 force-pushed the zhxchen17/precompile/2 branch from 6d27095 to 955e518 Compare October 6, 2025 15:45
Contributor Author

rebased

Collaborator

@ProExpertProg ProExpertProg left a comment


LGTM overall!


```python
m.setenv("VLLM_USE_AOT_COMPILE", "1")
torch._dynamo.reset()
with use_vllm_config(vllm_config), torch.compiler.set_stance(
```
Collaborator

why duplicate use_vllm_config?

Contributor Author

nice catch. I think it's by accident and I will remove this.

```python
torch._dynamo.reset()
with use_vllm_config(vllm_config), torch.compiler.set_stance(
        "fail_on_recompile"):
    actual = CompiledMod(vllm_config=vllm_config)(*args)
```
Collaborator

Why doesn't this fail - where does the compiled code come from? Does the previous run that raised a recompile error create it? Or does it come from the cache?

Contributor Author

I think the name of the API torch.compiler.set_stance('fail_on_recompile') is the source of confusion here. Basically torch.compile() now has two modes: JIT and AOT. torch.compiler.set_stance('fail_on_recompile') means torch.compile() will fail when we recompile in JIT mode.

Here, by setting VLLM_USE_AOT_COMPILE=1, we're testing that torch.compile() JIT mode is not triggered. We're not testing the loading behavior in this unit test yet (we'll test the loading part in the following tests). In other words, we are just testing that we're using the correct AOT compile API from torch.

I think it would be possible to address this by naming the API something like set_stance("fail_on_new_cache_entry") or better, but the behavior here is just about JIT vs AOT.



```python
class VllmSerializableFunction(SerializableCallable):
    """
```
Collaborator

Can we add some more comments on what this does, not just why it's needed? Devs will navigate here from the use site and should be informed that this is mostly a wrapper around graph_module, so they can skip through it if they're not interested in the serialization.

```python
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"])

def optimized_call(*example_inputs):
```
Collaborator

The control flow here seems a bit complex, could we add a comment or two?

@zhxchen17 zhxchen17 force-pushed the zhxchen17/precompile/2 branch from 955e518 to de876f4 Compare October 10, 2025 16:05
Contributor Author

Added comments and rebased.

@zou3519 zou3519 added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 10, 2025
@zhxchen17 zhxchen17 force-pushed the zhxchen17/precompile/2 branch from de876f4 to 7935089 Compare October 10, 2025 16:35
@zou3519 zou3519 enabled auto-merge (squash) October 10, 2025 16:43
auto-merge was automatically disabled October 10, 2025 18:49

Head branch was pushed to by a user without write access

@zhxchen17 zhxchen17 force-pushed the zhxchen17/precompile/2 branch from 31bca0d to df40fe6 Compare October 10, 2025 18:49
@zhxchen17 zhxchen17 force-pushed the zhxchen17/precompile/2 branch from df40fe6 to aba7a85 Compare October 10, 2025 19:16
@zou3519 zou3519 merged commit eef921f into vllm-project:main Oct 10, 2025
46 checks passed
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
```python
def use_aot_compile() -> bool:
    from vllm.utils import is_torch_equal_or_newer

    default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
```
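For illustration, the version gate can be mimicked with a stdlib-only stand-in for is_torch_equal_or_newer (the parsing here is a rough hypothetical; the real helper lives in vllm.utils):

```python
def is_version_equal_or_newer(version: str, target: str) -> bool:
    # Compare dotted release components numerically, stopping at suffixes
    # such as ".dev" (rough stand-in for vllm.utils.is_torch_equal_or_newer).
    def key(v):
        parts = []
        for p in v.split("."):
            digits = "".join(ch for ch in p if ch.isdigit())
            if not digits:
                break
            parts.append(int(digits))
        return parts
    return key(version) >= key(target)

print(is_version_equal_or_newer("2.10.0", "2.10.0.dev"))  # True
print(is_version_equal_or_newer("2.9.1", "2.10.0.dev"))   # False
```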
Contributor

@zou3519 From this line, it makes sense that the logic here runs on PyTorch CI when we test against the PyTorch main branch. So, there are a couple of failures there https://github.com/pytorch/pytorch/actions/runs/18522236183/job/52791051622 blocking the vLLM commit pin update.

Collaborator

cc @zhxchen17 these tests are failing on PyTorch main, can you take a look please?

Contributor Author

Sure I can take a look @zou3519 @huydhn

Contributor Author

pytorch/pytorch#165702 should partially fix the test failures on main branch. I will do a full test on test_basic_correctness and report back.

Contributor

@huydhn huydhn Oct 21, 2025

If you need to test your fix on the PyTorch side, you could bump the pinned vLLM commit we have at https://github.com/pytorch/pytorch/blob/main/.github/ci_commit_pins/vllm.txt to a recent one from your PR, then add ciflow/vllm to run the tests on vLLM x PyTorch main.

Contributor Author

Thanks. I found 3 issues on the vLLM side and am working on fixing them:
#27285
#27288
#27350

Once they are merged (each should be a minor change of around 1-2 lines), I will bump the vLLM pin to the latest and send a PyTorch PR with the label ciflow/vllm.

Contributor Author

All fixes have landed for now. Trying to update the pin with pytorch/pytorch#166494.

bbartels pushed a commit to bbartels/vllm that referenced this pull request Oct 16, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025

Labels

ci/build ready ONLY add when PR is ready to merge/full CI is needed


6 participants