From 2caec1230b98d74087375ce3c101bc0c292c9afc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 30 Dec 2024 08:59:02 +0800 Subject: [PATCH 1/8] collect all files Signed-off-by: youkaichao --- vllm/compilation/decorators.py | 17 ++++++++++++++++- vllm/config.py | 1 + 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 805a217ee6ca..fed8103c5b0c 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -1,8 +1,10 @@ import inspect from typing import Callable, Dict, List, Optional, TypeVar, Union, overload +from unittest.mock import patch import torch import torch.nn as nn +from torch._dynamo.symbolic_convert import InliningInstructionTranslator from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher @@ -196,7 +198,20 @@ def __call__(self, *args, **kwargs): # we need to control all the compilation of the model. torch._dynamo.eval_frame.remove_from_cache( self.original_code_object) - return self.compiled_callable(*args, **kwargs) + self.vllm_config.compilation_config.traced_files.add( + self.original_code_object.co_filename) + inline_call = InliningInstructionTranslator.inline_call + + def patched_inline_call(parent, func, args, kwargs): + code = func.get_code() + self.vllm_config.compilation_config.traced_files.add( + code.co_filename) + return inline_call(parent, func, args, kwargs) + + with patch.object(InliningInstructionTranslator, 'inline_call', + patched_inline_call): + output = self.compiled_callable(*args, **kwargs) + return output # usually, capturing the model once is enough, and then we can # dispatch to the compiled code directly, without going through diff --git a/vllm/config.py b/vllm/config.py index 8e556743c852..db981dc73bc6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2751,6 +2751,7 @@ def model_post_init(self, __context: Any) -> None: # keep track of enabled and disabled custom ops enabled_custom_ops: Counter[str] = PrivateAttr disabled_custom_ops: Counter[str] = PrivateAttr + traced_files: Set[str] = PrivateAttr compilation_time: float = PrivateAttr # should be InductorHashCache, but Pydantic does not support it inductor_hash_cache: Any = PrivateAttr From 0782f1018ebecbe80698128f810a2dc7f14d7c8e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 30 Dec 2024 09:02:35 +0800 Subject: [PATCH 2/8] init Signed-off-by: youkaichao --- vllm/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index db981dc73bc6..dc2c21419730 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2792,6 +2792,7 @@ def __repr__(self) -> str: "compilation_time", "bs_to_padded_graph_size", "pass_config", + "traced_files", } return self.model_dump_json(exclude=exclude, exclude_unset=True) @@ -2851,6 +2852,7 @@ def model_post_init(self, __context: Any) -> None: self.enabled_custom_ops = Counter() self.disabled_custom_ops = Counter() + self.traced_files = set() self.static_forward_context = {} self.compilation_time = 0.0 From f2ef7a44fe501126d5c8c0a1d6cb1ce05babc882 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 30 Dec 2024 10:26:09 +0800 Subject: [PATCH 3/8] move hash cache to vllm backend Signed-off-by: youkaichao --- vllm/compilation/backends.py | 60 +++++++++++++++++++++++++++++++++--- vllm/config.py | 24 --------------- 2 files changed, 55 insertions(+), 29 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 
a8dd628b9cd6..0b4ed5eff00e 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -354,7 +354,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): def __init__(self, module: torch.fx.GraphModule, compile_submod_names: List[str], vllm_config: VllmConfig, - graph_pool): + graph_pool, vllm_backend: "VllmBackend"): super().__init__(module) from torch._guards import detect_fake_mode self.fake_mode = detect_fake_mode() @@ -362,6 +362,7 @@ def __init__(self, module: torch.fx.GraphModule, self.compilation_config = vllm_config.compilation_config self.graph_pool = graph_pool self.vllm_config = vllm_config + self.vllm_backend = vllm_backend def run(self, *args): fake_args = [ @@ -397,7 +398,7 @@ def call_module(self, target: torch.fx.node.Target, self.module.__dict__[target] = PiecewiseBackend( submod, self.vllm_config, self.graph_pool, index, len(self.compile_submod_names), sym_shape_indices, - compiled_graph_for_general_shape) + compiled_graph_for_general_shape, self.vllm_backend) compilation_counter.num_piecewise_capturable_graphs_seen += 1 @@ -472,6 +473,53 @@ def configure_post_pass(self): def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: + if not self.compilation_config.cache_dir: + # no provided cache dir, generate one based on the known factors + # that affects the compilation. if none of the factors change, + # the cache dir will be the same so that we can reuse the compiled + # graph. + + # 1. factors come from the vllm_config (it mainly summarizes how the + # model is created) + vllm_config = self.vllm_config + config_hash = vllm_config.compute_hash() + + # 2. factors come from the code files that are traced by Dynamo ( + # it mainly summarizes how the model is used in forward pass) + forward_code_files = list( + sorted(self.compilation_config.traced_files)) + self.compilation_config.traced_files.clear() + logger.debug( + "Traced files: %s (to be considered for compilation cache)", + forward_code_files) + hash_content = [] + for filepath in forward_code_files: + hash_content.append(filepath) + with open(filepath) as f: + hash_content.append(f.read()) + import hashlib + code_hash = hashlib.md5( + "\n".join(hash_content).encode()).hexdigest() + + # combine the two hashes to generate the cache dir + hash_key = hashlib.md5( + f"{config_hash}_{code_hash}".encode()).hexdigest()[:10] + cache_dir = os.path.join( + envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key, + f"rank_{vllm_config.parallel_config.rank}") + else: + cache_dir = self.compilation_config.cache_dir + os.makedirs(cache_dir, exist_ok=True) + + disabled = envs.VLLM_DISABLE_COMPILE_CACHE + self.inductor_hash_cache: InductorHashCache = InductorHashCache( + cache_dir, disabled=disabled) + if disabled: + logger.info("vLLM's torch.compile cache is disabled.") + else: + logger.info("Using cache directory: %s for vLLM's torch.compile", + cache_dir) + # when dynamo calls the backend, it means the bytecode # transform and analysis are done compilation_counter.num_graphs_seen += 1 @@ -507,8 +555,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # propagate the split graph to the piecewise backend, # compile submodules with symbolic shapes PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, - self.vllm_config, - self.graph_pool).run(*example_inputs) + self.vllm_config, self.graph_pool, + self).run(*example_inputs) self._called = True @@ -577,7 +625,8 @@ class PiecewiseBackend: def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, 
graph_pool: Any, piecewise_compile_index: int, total_piecewise_compiles: int, sym_shape_indices: List[int], - compiled_graph_for_general_shape: Callable): + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend): """ The backend for piecewise compilation. It mainly handles the compilation and cudagraph capturing. @@ -597,6 +646,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.graph_pool = graph_pool self.piecewise_compile_index = piecewise_compile_index self.total_piecewise_compiles = total_piecewise_compiles + self.vllm_backend = vllm_backend self.is_first_graph = piecewise_compile_index == 0 self.is_last_graph = ( diff --git a/vllm/config.py b/vllm/config.py index ca263f080bb1..42af9628258d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3,7 +3,6 @@ import enum import hashlib import json -import os import warnings from contextlib import contextmanager from dataclasses import dataclass, field, replace @@ -2881,29 +2880,6 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: # merge with the config use_inductor assert self.level == CompilationLevel.PIECEWISE - if not self.cache_dir: - # no provided cache dir, generate one based on the known factors - # that affects the compilation. if none of the factors change, - # the cache dir will be the same so that we can reuse the compiled - # graph. - hash_key = vllm_config.compute_hash() - cache_dir = os.path.join( - envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key, - f"rank_{vllm_config.parallel_config.rank}") - os.makedirs(cache_dir, exist_ok=True) - self.cache_dir = cache_dir - - disabled = envs.VLLM_DISABLE_COMPILE_CACHE - from vllm.compilation.backends import InductorHashCache - self.inductor_hash_cache: InductorHashCache = InductorHashCache( - self.cache_dir, disabled=disabled) - if disabled: - logger.info("vLLM's torch.compile cache is disabled.") - else: - logger.info( - "Using cache directory: %s for vLLM's torch.compile", - self.cache_dir) - from vllm.compilation.backends import VllmBackend return VllmBackend(vllm_config) From 387747cc244dd74dd4bb8ae3b95cfe026f0ed9d2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 30 Dec 2024 12:52:32 +0800 Subject: [PATCH 4/8] finish Signed-off-by: youkaichao --- vllm/compilation/backends.py | 3 ++- vllm/config.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 0b4ed5eff00e..dc2e6a462812 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -431,6 +431,7 @@ class VllmBackend: post_grad_passes: Sequence[Callable] sym_tensor_indices: List[int] input_buffers: List[torch.Tensor] + inductor_hash_cache: InductorHashCache def __init__( self, @@ -684,7 +685,7 @@ def check_for_ending_compilation(self): if self.is_last_graph and not self.to_be_compiled_sizes: # no specific sizes to compile # save the hash of the inductor graph for the next run - self.compilation_config.inductor_hash_cache.save_to_file() + self.vllm_backend.inductor_hash_cache.save_to_file() end_monitoring_torch_compile(self.vllm_config) def __call__(self, *args) -> Any: diff --git a/vllm/config.py b/vllm/config.py index 42af9628258d..6c539013d2f3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2758,8 +2758,6 @@ def model_post_init(self, __context: Any) -> None: disabled_custom_ops: Counter[str] = PrivateAttr traced_files: Set[str] = PrivateAttr compilation_time: float = PrivateAttr - # should be InductorHashCache, but Pydantic does not support it 
- inductor_hash_cache: Any = PrivateAttr # Per-model forward context # Mainly used to store attention cls From 17f8648a766e4b99072aa15dcbfa2339e92cb8a1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 30 Dec 2024 13:00:17 +0800 Subject: [PATCH 5/8] fix Signed-off-by: youkaichao --- vllm/compilation/backends.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index dc2e6a462812..3541e01ada50 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -145,6 +145,7 @@ def wrap_inductor(graph: fx.GraphModule, example_inputs, additional_inductor_config, compilation_config: CompilationConfig, + vllm_backend: "VllmBackend", graph_index: int = 0, num_graphs: int = 1, runtime_shape: Optional[int] = None, @@ -176,7 +177,7 @@ def wrap_inductor(graph: fx.GraphModule, # see https://github.com/pytorch/pytorch/issues/138980 graph = copy.deepcopy(graph) - cache_data = compilation_config.inductor_hash_cache + cache_data = vllm_backend.inductor_hash_cache if (runtime_shape, graph_index) in cache_data: # we compiled this graph before # so we can directly lookup the compiled graph via hash @@ -196,7 +197,7 @@ def wrap_inductor(graph: fx.GraphModule, hash_str, example_inputs, True, False) assert inductor_compiled_graph is not None, ( "Inductor cache lookup failed. Please remove" - f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again." # noqa + f"the cache file {cache_data.cache_file_path} and try again." # noqa ) # Inductor calling convention (function signature): @@ -390,6 +391,7 @@ def call_module(self, target: torch.fx.node.Target, args, self.compilation_config.inductor_compile_config, self.compilation_config, + self.vllm_backend, graph_index=index, num_graphs=len(self.compile_submod_names), runtime_shape=None, @@ -713,6 +715,7 @@ def __call__(self, *args) -> Any: args, self.compilation_config.inductor_compile_config, self.compilation_config, + self.vllm_backend, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, runtime_shape=runtime_shape, From a8eaa491472a79e23d9fd0654848540b98fbb8ff Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 30 Dec 2024 13:04:10 +0800 Subject: [PATCH 6/8] pprint Signed-off-by: youkaichao --- vllm/compilation/backends.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 3541e01ada50..87655530cead 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -493,8 +493,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: sorted(self.compilation_config.traced_files)) self.compilation_config.traced_files.clear() logger.debug( - "Traced files: %s (to be considered for compilation cache)", - forward_code_files) + "Traced files (to be considered for compilation cache):\n%s", + "\n".join(forward_code_files)) hash_content = [] for filepath in forward_code_files: hash_content.append(filepath) From 6baccf5a4debed45be5ce442f4619ee45ca803b0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 30 Dec 2024 14:01:35 +0800 Subject: [PATCH 7/8] fix pp Signed-off-by: youkaichao --- vllm/sequence.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/sequence.py b/vllm/sequence.py index 34f910d47b7d..dafd0582b1ef 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1096,6 +1096,13 @@ class IntermediateTensors: tensors: Dict[str, torch.Tensor] + def __init__(self, tensors): + # 
manually define this function, so that + # Dynamo knows `IntermediateTensors()` comes from this file. + # Otherwise, dataclass will generate this function by evaluating + # a string, and we will lose the information about the source file. + self.tensors = tensors + def __getitem__(self, key: Union[str, slice]): if isinstance(key, str): return self.tensors[key] From b8b7217223d9e0fd734831a8fdf7fffdabb73946 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 Jan 2025 16:00:10 +0800 Subject: [PATCH 8/8] add comments Signed-off-by: youkaichao --- vllm/compilation/decorators.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index fed8103c5b0c..10513111ea7f 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -198,8 +198,19 @@ def __call__(self, *args, **kwargs): # we need to control all the compilation of the model. torch._dynamo.eval_frame.remove_from_cache( self.original_code_object) + + # collect all relevant files traced by Dynamo, + # so that the compilation cache can trigger re-compilation + # properly when any of these files change. + + # 1. the file containing the top-level forward function self.vllm_config.compilation_config.traced_files.add( self.original_code_object.co_filename) + + # 2. every time Dynamo sees a function call, it will inline + # the function by calling InliningInstructionTranslator.inline_call + # we hijack this function to know all the functions called + # during Dynamo tracing, and their corresponding files inline_call = InliningInstructionTranslator.inline_call def patched_inline_call(parent, func, args, kwargs):
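Taken together, patches 1 and 8 record every source file Dynamo touches while tracing the forward pass, so the compilation cache can be invalidated when any of those files changes. Below is a minimal, self-contained sketch of that mechanism; `record_traced_files` is a hypothetical helper name, while `InliningInstructionTranslator.inline_call`, `func.get_code()`, and `patch.object` are the same hooks the patches themselves use.

```python
from contextlib import contextmanager
from typing import Set
from unittest.mock import patch

from torch._dynamo.symbolic_convert import InliningInstructionTranslator


@contextmanager
def record_traced_files(traced_files: Set[str]):
    """Collect the filename of every function Dynamo inlines.

    Dynamo routes nested function calls through
    InliningInstructionTranslator.inline_call, so wrapping it while the
    compiled callable runs is enough to see every file that participates
    in the traced forward pass.
    """
    inline_call = InliningInstructionTranslator.inline_call

    def patched_inline_call(parent, func, args, kwargs):
        traced_files.add(func.get_code().co_filename)
        return inline_call(parent, func, args, kwargs)

    with patch.object(InliningInstructionTranslator, "inline_call",
                      patched_inline_call):
        yield


# Usage sketch: wrap the first call into the compiled model, and also add
# the file containing the top-level forward function itself.
# traced_files: Set[str] = set()
# traced_files.add(model.forward.__code__.co_filename)
# with record_traced_files(traced_files):
#     output = compiled_model(*args, **kwargs)
```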
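Patch 3 then folds those files into the cache key: the cache directory name is derived from a hash of the VllmConfig plus an MD5 over the traced files' paths and contents, so either a config change or an edit to any traced file yields a fresh directory. A sketch of that derivation follows, assuming the caller already has the config hash and the traced-file set; `compile_cache_dir` and its parameters are illustrative names, not part of vLLM's API.

```python
import hashlib
import os
from typing import Iterable


def compile_cache_dir(config_hash: str, traced_files: Iterable[str],
                      cache_root: str, rank: int) -> str:
    """Mirror of the cache-directory derivation added in patch 3."""
    # Hash the traced files by path *and* content, in a stable order.
    hash_content = []
    for filepath in sorted(traced_files):
        hash_content.append(filepath)
        with open(filepath) as f:
            hash_content.append(f.read())
    code_hash = hashlib.md5("\n".join(hash_content).encode()).hexdigest()

    # Combine with the config hash and truncate to a short directory name.
    hash_key = hashlib.md5(
        f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
    return os.path.join(cache_root, "torch_compile_cache", hash_key,
                        f"rank_{rank}")
```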
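Patches 4 and 5 move ownership of the inductor hash cache from CompilationConfig (where Pydantic could not type it) to VllmBackend, and `wrap_inductor` now receives the backend so it can look up previously compiled graphs keyed by `(runtime_shape, graph_index)`. The real `InductorHashCache` lives in vllm/compilation/backends.py; the tiny stand-in below only illustrates the lookup/save interface the diff relies on (`__contains__`, `__getitem__`, `save_to_file`, `cache_file_path`) and is not the actual implementation.

```python
import ast
import os
from typing import Dict, Optional, Tuple


class TinyHashCache:
    """Illustrative stand-in: maps (runtime_shape, graph_index) -> hash str."""

    def __init__(self, cache_dir: str, disabled: bool = False):
        self.disabled = disabled
        self.cache: Dict[Tuple[Optional[int], int], str] = {}
        os.makedirs(cache_dir, exist_ok=True)
        self.cache_file_path = os.path.join(cache_dir, "hash_cache.txt")
        if os.path.exists(self.cache_file_path):
            with open(self.cache_file_path) as f:
                self.cache = ast.literal_eval(f.read())

    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
        return not self.disabled and key in self.cache

    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
        return self.cache[key]

    def __setitem__(self, key: Tuple[Optional[int], int], hash_str: str):
        self.cache[key] = hash_str

    def save_to_file(self):
        # Persist the mapping so the next run can reuse compiled graphs.
        if self.disabled:
            return
        with open(self.cache_file_path, "w") as f:
            f.write(repr(self.cache))
```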
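Patch 7 exists because the file-content hash above can only see code that lives in a real file: a dataclass-generated `__init__` is produced by evaluating a string, so its code object does not point back to vllm/sequence.py, and tracing through `IntermediateTensors(...)` would contribute nothing useful to the cache key. The snippet below is a quick way to observe the difference; the exact placeholder filename for generated methods may vary across Python versions.

```python
import dataclasses


@dataclasses.dataclass
class Generated:
    tensors: dict


class HandWritten:
    def __init__(self, tensors: dict):
        self.tensors = tensors


# The generated __init__ has a synthetic filename (e.g. "<string>"),
# while the hand-written one reports the file it was defined in.
print(Generated.__init__.__code__.co_filename)
print(HandWritten.__init__.__code__.co_filename)
```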