From a4f3329a90f52f637395ca3c14e70b3892984551 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 7 Mar 2025 19:51:18 +0800 Subject: [PATCH 1/9] fix inductor Signed-off-by: youkaichao --- vllm/compilation/compiler_interface.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index ac0544ad6403..d280fdfbe0d6 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -155,6 +155,7 @@ def initialize_cache(self, cache_dir: str, disable_cache: bool = False): triton_cache = os.path.join(cache_dir, "triton_cache") os.makedirs(triton_cache, exist_ok=True) os.environ["TRITON_CACHE_DIR"] = triton_cache + self.cache_dir = cache_dir def compile( self, @@ -200,7 +201,19 @@ def compile( def hijack_load(*args, **kwargs): inductor_compiled_graph = original_load(*args, **kwargs) nonlocal file_path - file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa + compiled_fn = inductor_compiled_graph.current_callable + file_path = compiled_fn.__code__.co_filename # noqa + if not file_path.startswith(self.cache_dir): + # hooked in the align_inputs_from_check_idxs function + # in torch/_inductor/utils.py + for cell in compiled_fn.__closure__: + if not callable(cell.cell_contents): + continue + if cell.cell_contents.__code__.co_filename.startswith( + self.cache_dir): + # this is the real file path compiled from Inductor + file_path = cell.cell_contents.__code__.co_filename + break return inductor_compiled_graph hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa From ab0d66a05ebd13350671fe27b78d9839c77725e0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 7 Mar 2025 20:35:14 +0800 Subject: [PATCH 2/9] add doc Signed-off-by: youkaichao --- docs/source/index.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/index.md b/docs/source/index.md index a6806900cb3c..0bd8e12d088a 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -157,6 +157,7 @@ design/multiprocessing :caption: V1 Design Documents :maxdepth: 2 +design/v1/torch_compile design/v1/prefix_caching design/v1/metrics ::: From 47b5cecfde10db3c2039d454c9f89e6498bde22b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 7 Mar 2025 20:35:43 +0800 Subject: [PATCH 3/9] fix Signed-off-by: youkaichao --- docs/source/design/v1/torch_compile.md | 131 +++++++++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 docs/source/design/v1/torch_compile.md diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md new file mode 100644 index 000000000000..a82ad49881e5 --- /dev/null +++ b/docs/source/design/v1/torch_compile.md @@ -0,0 +1,131 @@ +# vLLM's `torch.compile` integration + +In vLLM's V1 architecture, `torch.compile` is enabled by default and is a critical part of the framework. This document gives a simple walk-through example to show how to understand the `torch.compile` usage. + +Throughout the example, we will run a common Llama model using v1, and turn on debug level logging to show all the details. The command to be used is `VLLM_USE_V1=1 VLLM_LOGGING_LEVEL=DEBUG vllm serve meta-llama/Llama-3.2-1B`. 
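Before walking through the logs, it helps to keep in mind what `torch.compile` does in isolation. The following is a minimal standalone sketch (a made-up toy module, not vLLM code) showing the three ingredients the rest of this document refers to: Dynamo capturing a module's `forward`, a dynamic (symbolic) batch dimension, and Inductor as the compilation backend.

```python
# Standalone sketch, not vLLM code: a toy module compiled with torch.compile.
import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(2048, 2048)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

model = TinyModel()
# Dynamo captures the Python code of `forward`; Inductor compiles the captured graph.
compiled_model = torch.compile(model, backend="inductor")

x = torch.randn(8, 2048)
# Mark dim 0 (the batch size / number of tokens) as dynamic, so different
# batch sizes reuse the same compiled artifact instead of recompiling.
torch._dynamo.mark_dynamic(x, 0)
compiled_model(x)                      # first call: capture + compile
compiled_model(torch.randn(16, 2048))  # later batch sizes reuse the compiled code
```

vLLM builds on exactly this machinery, but adds its own caching, graph splitting, and cudagraph layers on top, as described below.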
## Compilation Cache

In the very verbose logs, we can see:

```
INFO 03-07 03:06:55 [backends.py:409] Using cache directory: ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0 for vLLM's torch.compile
```

vLLM will take all the available factors into consideration, and decide a directory to store all the compilation artifact. This means, you can directly copy the whole `~/.cache/vllm/torch_compile_cache` directory in your deployment scenario to save a great amount of compilation time, and hence accelerating the starting time of the vLLM instance.

A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all the compilation finishes before we serve any requests. No requests will trigger new compilations. Otherwise, the engine would be blocked on that request, and the response time will have unexpected spikes.

## Python Code Compilation

In the very verbose logs, we can see:

```
DEBUG 03-07 03:06:52 [decorators.py:203] Start compiling function <code object forward at 0x..., file "xxx/vllm/model_executor/models/llama.py", line 339>

DEBUG 03-07 03:06:54 [backends.py:370] Traced files (to be considered for compilation cache):
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/_dynamo/polyfills/builtins.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/nn/modules/container.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/nn/modules/module.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/attention/layer.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/distributed/communication_op.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/distributed/parallel_state.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/custom_op.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/activation.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/layernorm.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/linear.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/rotary_embedding.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/vocab_parallel_embedding.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/models/llama.py

DEBUG 03-07 03:07:07 [backends.py:462] Computation graph saved to ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/computation_graph.py
DEBUG 03-07 03:07:07 [wrapper.py:105] Dynamo transformed code saved to ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/transformed_code.py
```

This is the Python code compilation, i.e. graph capture by Dynamo. Dynamo traces the `forward` function of the model we compile, whose code object lives at `xxx/vllm/model_executor/models/llama.py:339`. During the forward pass, other functions are also called and inlined by Dynamo, as shown by the logs, including some PyTorch functions from `xxx/torch/nn/modules/module.py` (used by PyTorch `nn.Module`, because module attribute access triggers a function call) and some communication / attention / activation functions from vLLM. All the traced files are considered when we decide the cache directory to use. This way, any code change in the above files will trigger a compilation cache miss, and therefore recompilation.

The result of the Dynamo compilation is a new function stored in `~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/transformed_code.py`. Usually, this function unpacks tensors from the module and then passes them to the traced computation graph.
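If you are curious what such a captured computation graph looks like, you can reproduce the idea outside of vLLM with a trivial Dynamo backend that just prints whatever graph it receives. This is a standalone sketch with a made-up toy model, not vLLM's actual integration code:

```python
# Standalone sketch, not vLLM code: print the FX graph captured by Dynamo.
import torch
from torch import nn

def print_graph_backend(gm: torch.fx.GraphModule, example_inputs):
    print(gm.code)     # the Python code generated for the captured computation graph
    return gm.forward  # run the captured graph as-is (no Inductor compilation)

class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(32, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

compiled_model = torch.compile(TinyModel(), backend=print_graph_backend)
compiled_model(torch.randn(4, 32))  # triggers Dynamo tracing and prints the graph
```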
The computation graph is stored in `~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/computation_graph.py`.

## Computation Graph Processing

The computation graph has shape annotations for every tensor. The inputs are the input ids, the position ids, and the weights and buffers of the model, and the outputs are the final hidden states. Note that the lm head projection and sampling operations are not included in the graph.

Most of the inputs to the computation graph have static shapes, since they are model weights and buffers and will not change during the lifetime of the model. Only the input ids and position ids have symbolic shapes, i.e. shapes that can change from batch to batch. However, they share the same symbolic shape. That is to say, the only size that varies in the computation graph is the batch size (the number of tokens processed in the current forward pass).

The attention operation is complicated, and it needs to interact with kv caches, whose shapes are complicated as well. Fortunately, the output of the attention operation shares the same shape as its input query. Therefore, we wrap the whole attention operation into a PyTorch custom op `torch.ops.vllm.unified_attention_with_output`, so that Dynamo will not try to inspect any of its internal operations. This way, although the attention operation is complicated, we can still capture the model's computation graph as a full graph, from Dynamo's perspective.

The computation graph is further split into pieces by the `splitting_ops` (usually the attention operation). Therefore, in the `~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/computation_graph.py` file, we can see lots of submodules, each of which is one piece of the graph after splitting:

- The attention operation itself is a submodule.
- The part of the computation graph from one attention operation to the next attention operation is a submodule.

Every submodule can be identified by its index, and will be processed individually.

## Computation Graph Compilation

In the very verbose logs, we can also see:

```
DEBUG 03-07 03:52:37 [backends.py:134] store the 0-th graph for shape None from inductor via handle ('fpegyiq3v3wzjzphd45wkflpabggdbjpylgr7tta4hj6uplstsiw', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/iw/ciwzrk3ittdqatuzwonnajywvno3llvjcs2vfdldzwzozn3zi3iy.py')
DEBUG 03-07 03:52:39 [backends.py:134] store the 1-th graph for shape None from inductor via handle ('f7fmlodmf3h3by5iiu2c4zarwoxbg4eytwr3ujdd2jphl4pospfd', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/ly/clyfzxldfsj7ehaluis2mca2omqka4r7mgcedlf6xfjh645nw6k2.py')
...
+DEBUG 03-07 03:52:45 [backends.py:134] store the 15-th graph for shape None from inductor via handle ('f7fmlodmf3h3by5iiu2c4zarwoxbg4eytwr3ujdd2jphl4pospfd', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/ly/clyfzxldfsj7ehaluis2mca2omqka4r7mgcedlf6xfjh645nw6k2.py') +DEBUG 03-07 03:52:45 [backends.py:134] store the 16-th graph for shape None from inductor via handle ('fvj3ccoi7m34f3dnr4itmu55mmun44l5xymwhrjlwisylsk7q6jy', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/tf/ctfftkglj7b4lcttq5cymx6cew372uoauupqn6ldsvpiucavqcjc.py') +``` + +This means, the first piece of computation graph, with shape `None` (for symbolic shape), is compiled by Inductor, with a key `fpegyiq3v3wzjzphd45wkflpabggdbjpylgr7tta4hj6uplstsiw`, and the compiled kernel is stored in `~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/iw/ciwzrk3ittdqatuzwonnajywvno3llvjcs2vfdldzwzozn3zi3iy.py`. You can open the file to see what is the code Inductor finally runs. + +One more detail: you can see that the 1-th graph and the 15-th graph have the same key, while the 0-th graph and the 16-th graph are different. This is expected, since we split the graph by the attention op, we get 3 unique subgraphs: + +- the first layer before attention +- every middle layer, from one attention operation to the next attention operation +- the final layer after attention + +If we already have the cache directory (e.g. run the same code for the second time), we will see the following logs: + +``` +DEBUG 03-07 04:00:45 [backends.py:86] Directly load the 0-th graph for shape None from inductor via handle ('fpegyiq3v3wzjzphd45wkflpabggdbjpylgr7tta4hj6uplstsiw', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/iw/ciwzrk3ittdqatuzwonnajywvno3llvjcs2vfdldzwzozn3zi3iy.py') +``` + +This time, Inductor compilation is completely bypassed, and we will load from disk to read the compilation artifact we get from the last time. + +The above example just uses Inductor to compile for a general shape (i.e. symbolic shape). We can also use Inductor to compile for some of the specific shapes, for example: + +`VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.2-1B --compilation_config "{'compile_sizes': [1, 2, 4, 8]}"` + +Then it will also compile a specific kernel just for batch size `1, 2, 4, 8`. At this time, all of the shapes in the computation graph are static and known, and we will turn on auto-tuning to tune for max performance. This can be slow when you run it for the first time, but the next time you run it, we can directly bypass the tuning and run the tuned kernel. + +When all the shapes are known, `torch.compile` can compare different configs, and often find some better configs to run the kernel. 
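Outside of vLLM, you can trigger the same kind of auto-tuning with plain `torch.compile` by compiling a fixed-shape operation in `max-autotune` mode. The following standalone sketch (a made-up example that needs a GPU with Triton available, not vLLM code) uses the same matmul shape that appears in the log below:

```python
# Standalone sketch, not vLLM code: auto-tune a fixed-shape matmul with Inductor.
import torch

@torch.compile(mode="max-autotune")
def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b

a = torch.randn(8, 2048, device="cuda", dtype=torch.float16)
b = torch.randn(2048, 3072, device="cuda", dtype=torch.float16)
out = matmul(a, b)  # the first call triggers compilation and auto-tuning
```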
For example, we can see the following log: + +``` +AUTOTUNE mm(8x2048, 2048x3072) + triton_mm_4 0.0130 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2 + triton_mm_8 0.0134 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4 + triton_mm_12 0.0148 ms 87.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4 + mm 0.0160 ms 81.6% + triton_mm_16 0.0165 ms 78.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 + triton_mm_3 0.0199 ms 65.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2 + triton_mm_1 0.0203 ms 64.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2 + triton_mm_7 0.0203 ms 64.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 + triton_mm_2 0.0208 ms 62.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4 + triton_mm_11 0.0215 ms 60.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 +SingleProcess AUTOTUNE benchmarking takes 2.0428 seconds and 7.5727 seconds precompiling +``` + +It means, for a matrix multiplication with shape `8x2048x3072`, `torch.compile` tries triton template with various configs, and it is much faster than the default code (which dispatches to cublas library). + +Unfortunately, because auto-tuning takes quite a long time (even though it can be cached for later use), for the sake of user-friendliness, we turn it off by default. If you want to have max performance, it is recommended to try it, by compiling specific shapes. + +## Cudagraph Capture + +vLLM's V1 architecture uses piecewise cudagraph. The full computation graph is split as mentioned above, and we only capture the cudagraph for the piece of graph between attention operations (including the first graph before any attention operation, and the last graph after all the attention operation). This is based on the common observation, that computation between attention are usually token-wise, and easy to deal with for cudagraph, while the attention operation is non-trival to be cudagraph compatible. By running the attention operation in eager mode, we keep the flexibility of the attention operation, while the rest operations run in cudagraph mode. + +The piecewise cudagraph also has fine-grained memory management. The purpose is to only exclude the attention kernel from cudagraph, while keeping all the rest modules and the memory allocation operations in the cudagraph. This is why the attention operation in V1 has the output tensor as the input of the attention. + +The cudagraphs are captured and managed by the compiler backend, and replayed when the batch size has corresponding cudagraph captured. 
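For readers unfamiliar with cudagraphs, the following standalone sketch shows the general capture/replay pattern in plain PyTorch (a made-up toy function that needs a CUDA device; this is not vLLM's internal code). It illustrates why the input buffers have to be reused across replays:

```python
# Standalone sketch, not vLLM code: capture and replay a CUDA graph.
# A captured graph records fixed memory addresses, so new inputs must be
# copied into the same static buffers before every replay.
import torch

def step(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x * 2.0 + 1.0)

static_input = torch.zeros(8, 2048, device="cuda")

# Warm up on a side stream before capture, as recommended by PyTorch.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = step(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = step(static_input)  # operations are recorded, not run

# Replay with new data: copy into the static input buffer, then replay.
new_batch = torch.randn(8, 2048, device="cuda")
static_input.copy_(new_batch)
graph.replay()
result = static_output.clone()
```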
The caller of the model (model runner) only needs to make sure it manages the input buffers correctly. All of the intermediate buffers are managed automatically by the compiler backend. + +By default, vLLM will try to determine a set of sizes to capture cudagraph. You can also override it using the config `cudagraph_capture_sizes`: + +`VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.2-1B --compilation_config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"` + +Then it will only capture cudagraph for the specified sizes. It can be useful to have fine-grained control over the cudagraph capture. From 71e9eeb6c271b5f32d11fbefa57d0e1dde874f35 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 8 Mar 2025 01:22:18 +0800 Subject: [PATCH 4/9] explain how to disable Signed-off-by: youkaichao --- docs/source/design/v1/torch_compile.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md index a82ad49881e5..978f2fae2200 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/source/design/v1/torch_compile.md @@ -14,6 +14,8 @@ INFO 03-07 03:06:55 [backends.py:409] Using cache directory: ~/.cache/vllm/torch vLLM will take all the available factors into consideration, and decide a directory to store all the compilation artifact. This means, you can directly copy the whole `~/.cache/vllm/torch_compile_cache` directory in your deployment scenario to save a great amount of compilation time, and hence accelerating the starting time of the vLLM instance. +The compilation cache is enabled by default. You can disable it by setting `VLLM_DISABLE_COMPILE_CACHE=1`. This is useful if you want to debug the compilation process, or if you suspect the cache is causing some issues. + A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all the compilation finishes before we serve any requests. No requests will trigger new compilations. Otherwise, the engine would be blocked on that request, and the response time will have unexpected spikes. ## Python Code Compilation From 62c0ee3a8b8242ce241f8d6367669861efdb531b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 8 Mar 2025 01:28:36 +0800 Subject: [PATCH 5/9] explain factors Signed-off-by: youkaichao --- docs/source/design/v1/torch_compile.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md index 978f2fae2200..b74757b0db1b 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/source/design/v1/torch_compile.md @@ -14,7 +14,13 @@ INFO 03-07 03:06:55 [backends.py:409] Using cache directory: ~/.cache/vllm/torch vLLM will take all the available factors into consideration, and decide a directory to store all the compilation artifact. This means, you can directly copy the whole `~/.cache/vllm/torch_compile_cache` directory in your deployment scenario to save a great amount of compilation time, and hence accelerating the starting time of the vLLM instance. -The compilation cache is enabled by default. You can disable it by setting `VLLM_DISABLE_COMPILE_CACHE=1`. This is useful if you want to debug the compilation process, or if you suspect the cache is causing some issues. 
+The factors considered include: + +- All the related configs (see the `compute_hash` functions in the [config.py](gh-file:vllm/vllm/config.py)) +- PyTorch configs (see the `compute_hash` functions in the [compiler_interface.py](gh-file:vllm/compilation/compiler_interface.py)) +- The model's forward function and the relevant functions called by the forward function (see below) + +With all these factors taken into consideration, usually we can guarantee that the cache is safe to use, and will not cause any unexpected behavior. Therefore, the cache is enabled by default. If you want to debug the compilation process, or if you suspect the cache is causing some issues, you can disable it by setting the environment variable `VLLM_DISABLE_COMPILE_CACHE=1`. A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all the compilation finishes before we serve any requests. No requests will trigger new compilations. Otherwise, the engine would be blocked on that request, and the response time will have unexpected spikes. From 24aa01b960eef3dfb8d6b6a17d3b5a166bfa25d4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 8 Mar 2025 01:30:01 +0800 Subject: [PATCH 6/9] apply Signed-off-by: youkaichao --- docs/source/design/v1/torch_compile.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md index b74757b0db1b..55ad94eb8c90 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/source/design/v1/torch_compile.md @@ -81,7 +81,7 @@ DEBUG 03-07 03:52:45 [backends.py:134] store the 15-th graph for shape None from DEBUG 03-07 03:52:45 [backends.py:134] store the 16-th graph for shape None from inductor via handle ('fvj3ccoi7m34f3dnr4itmu55mmun44l5xymwhrjlwisylsk7q6jy', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/tf/ctfftkglj7b4lcttq5cymx6cew372uoauupqn6ldsvpiucavqcjc.py') ``` -This means, the first piece of computation graph, with shape `None` (for symbolic shape), is compiled by Inductor, with a key `fpegyiq3v3wzjzphd45wkflpabggdbjpylgr7tta4hj6uplstsiw`, and the compiled kernel is stored in `~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/iw/ciwzrk3ittdqatuzwonnajywvno3llvjcs2vfdldzwzozn3zi3iy.py`. You can open the file to see what is the code Inductor finally runs. +This means the first piece of computation graph (with shape `None` for symbolic shape) is compiled by Inductor (with a key `fpegyiq3v3wzjzphd45wkflpabggdbjpylgr7tta4hj6uplstsiw`). The compiled kernel is stored in `~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/iw/ciwzrk3ittdqatuzwonnajywvno3llvjcs2vfdldzwzozn3zi3iy.py`. You can open the file to see what is the code Inductor finally runs. One more detail: you can see that the 1-th graph and the 15-th graph have the same key, while the 0-th graph and the 16-th graph are different. 
This is expected, since we split the graph by the attention op, we get 3 unique subgraphs: From 8e32aa66307a1cb6020e7fbe27cc8a9f402ae71d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 8 Mar 2025 01:31:42 +0800 Subject: [PATCH 7/9] apply Signed-off-by: youkaichao --- docs/source/design/v1/torch_compile.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md index 55ad94eb8c90..d1cbc149fc38 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/source/design/v1/torch_compile.md @@ -126,7 +126,7 @@ Unfortunately, because auto-tuning takes quite a long time (even though it can b ## Cudagraph Capture -vLLM's V1 architecture uses piecewise cudagraph. The full computation graph is split as mentioned above, and we only capture the cudagraph for the piece of graph between attention operations (including the first graph before any attention operation, and the last graph after all the attention operation). This is based on the common observation, that computation between attention are usually token-wise, and easy to deal with for cudagraph, while the attention operation is non-trival to be cudagraph compatible. By running the attention operation in eager mode, we keep the flexibility of the attention operation, while the rest operations run in cudagraph mode. +vLLM's V1 architecture uses piecewise cudagraph. The full computation graph is split as mentioned above, and we only capture the cudagraph for the piece of graph between attention operations (including the first graph before any attention operation, and the last graph after all the attention operation). This is based on a common observation: computation between attentions are usually token-wise and easy to deal with for cudagraph; while the attention operation is non-trival to be cudagraph compatible. Thus, by running the attention operation in eager mode while the rest operations in cudagraph, we keep the flexibility of the attention operation. The piecewise cudagraph also has fine-grained memory management. The purpose is to only exclude the attention kernel from cudagraph, while keeping all the rest modules and the memory allocation operations in the cudagraph. This is why the attention operation in V1 has the output tensor as the input of the attention. From 0f30bdd0f6e5560ed772a37a675aeaff0329c532 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 8 Mar 2025 01:33:55 +0800 Subject: [PATCH 8/9] typical time Signed-off-by: youkaichao --- docs/source/design/v1/torch_compile.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md index d1cbc149fc38..f0a0e50217ae 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/source/design/v1/torch_compile.md @@ -122,7 +122,7 @@ SingleProcess AUTOTUNE benchmarking takes 2.0428 seconds and 7.5727 seconds prec It means, for a matrix multiplication with shape `8x2048x3072`, `torch.compile` tries triton template with various configs, and it is much faster than the default code (which dispatches to cublas library). -Unfortunately, because auto-tuning takes quite a long time (even though it can be cached for later use), for the sake of user-friendliness, we turn it off by default. If you want to have max performance, it is recommended to try it, by compiling specific shapes. 
+Unfortunately, because auto-tuning takes quite a long time (from seconds to minutes, depending on the model size and the batch size), even though it can be cached for later use, for the sake of user-friendliness, we turn it off by default. If you want to have max performance, it is recommended to try it, by compiling specific shapes. ## Cudagraph Capture From 3ee1f50f2831e341ecc051d31328a94f5038e847 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 8 Mar 2025 01:51:08 +0800 Subject: [PATCH 9/9] fix path Signed-off-by: youkaichao --- docs/source/design/v1/torch_compile.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md index f0a0e50217ae..0dadc8089991 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/source/design/v1/torch_compile.md @@ -16,7 +16,7 @@ vLLM will take all the available factors into consideration, and decide a direct The factors considered include: -- All the related configs (see the `compute_hash` functions in the [config.py](gh-file:vllm/vllm/config.py)) +- All the related configs (see the `compute_hash` functions in the [config.py](gh-file:vllm/config.py)) - PyTorch configs (see the `compute_hash` functions in the [compiler_interface.py](gh-file:vllm/compilation/compiler_interface.py)) - The model's forward function and the relevant functions called by the forward function (see below)