Skip to content

Commit 6581378

Browse files
authored
[torch.compile] add warning for unsupported models (#10622)
Signed-off-by: youkaichao <[email protected]>
1 parent 7c2134b commit 6581378

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

vllm/compilation/counter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
@dataclasses.dataclass
77
class CompilationCounter:
8+
num_models_seen: int = 0
89
num_graphs_seen: int = 0
910
# including the splitting ops
1011
num_piecewise_graphs_seen: int = 0

vllm/compilation/decorators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55

6+
from vllm.compilation.counter import compilation_counter
67
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
78
from vllm.config import CompilationLevel, VllmConfig
89
from vllm.logger import init_logger
@@ -130,6 +131,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
130131
] or not supports_dynamo()
131132
if self.do_not_compile:
132133
return
134+
compilation_counter.num_models_seen += 1
133135
TorchCompileWrapperWithCustomDispatcher.__init__(
134136
self, compilation_level=vllm_config.compilation_config.level)
135137

vllm/plugins/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ def set_current_vllm_config(vllm_config: "VllmConfig"):
8080
"""
8181
global _current_vllm_config
8282
old_vllm_config = _current_vllm_config
83+
from vllm.compilation.counter import compilation_counter
84+
from vllm.config import CompilationLevel
85+
num_models_seen = compilation_counter.num_models_seen
8386
try:
8487
_current_vllm_config = vllm_config
8588
yield
@@ -88,6 +91,18 @@ def set_current_vllm_config(vllm_config: "VllmConfig"):
8891
vllm_config.compilation_config.enabled_custom_ops)
8992
logger.debug("disabled custom ops: %s",
9093
vllm_config.compilation_config.disabled_custom_ops)
94+
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
95+
and compilation_counter.num_models_seen == num_models_seen:
96+
# If the model supports compilation,
97+
# compilation_counter.num_models_seen should be increased
98+
# by at least 1.
99+
# If it is not increased, it means the model does not support
100+
# compilation (does not have @support_torch_compile decorator).
101+
logger.warning(
102+
"`torch.compile` is turned on, but the model %s"
103+
" does not support it. Please open an issue on GitHub"
104+
"if you want it to be supported.",
105+
vllm_config.model_config.model)
91106
_current_vllm_config = old_vllm_config
92107

93108

0 commit comments

Comments
 (0)