
Conversation

@NickLucche
Collaborator

@NickLucche NickLucche commented Mar 18, 2025

Carrying on @mgoin's work from #14254.

This PR pre-compiles the multimodal encoder for various mm-item batch sizes. We're mostly concerned with LLaVA-style vision-language models (fixed input resolution), but the pre-compilation code already accounts for other modalities (audio, etc.). We're not addressing models with dynamic image sizes just yet (Pixtral, etc.).
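
For context, here is a minimal sketch of the bucketed pre-compilation idea; the helper name, bucket sizes, and image resolution below are illustrative, not the actual tpu_model_runner code:

import torch
import torch_xla.core.xla_model as xm

# Illustrative bucket sizes for the number of multimodal items per batch.
MM_ITEM_BUCKETS = [1, 2, 4, 8]

def precompile_vision_encoder(vision_encoder, image_shape=(3, 336, 336)):
    """Run the encoder once per bucket so XLA traces and caches each graph."""
    device = xm.xla_device()
    for num_items in MM_ITEM_BUCKETS:
        # Fixed-resolution dummy images, as in LLaVA-style models.
        dummy = torch.zeros(num_items, *image_shape, device=device)
        with torch.no_grad():
            vision_encoder(dummy)
        # Cut the lazy graph here so each bucket compiles as its own program.
        xm.mark_step()
        xm.wait_device_ops()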

Server:

VLLM_USE_V1=1 vllm serve llava-hf/llava-1.5-7b-hf --max-model-len 2512 --max-num-seqs 16 --max-num-batched-tokens 128 --chat-template examples/template_llava.jinja 

INFO 03-18 17:17:42 [tpu_model_runner.py:816] Encoder cache will be initialized with a budget of 576 tokens, and profiled with 1 image items of the maximum feature size.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
INFO 03-18 17:17:47 [tpu_model_runner.py:837] Multimodal Encoder profiling finished in in 4.21 [secs].
INFO 03-18 17:18:04 [kv_cache_utils.py:537] GPU KV cache size: 28,976 tokens
INFO 03-18 17:18:04 [kv_cache_utils.py:540] Maximum concurrency for 2,512 tokens per request: 11.54x
INFO 03-18 17:18:04 [tpu_model_runner.py:865] Compiling Multimodal image Encoder with different input shapes.
INFO 03-18 17:18:04 [tpu_model_runner.py:871]   -- mode: image items: 1
INFO 03-18 17:18:08 [tpu_model_runner.py:887] Multimodal image Encoder compilation finished in in 3.90 [secs].
INFO 03-18 17:18:08 [tpu_model_runner.py:891] Compiling the model with different input shapes.
INFO 03-18 17:18:37 [tpu_model_runner.py:896]   -- num_tokens: 16
INFO 03-18 17:18:47 [tpu_model_runner.py:896]   -- num_tokens: 32
INFO 03-18 17:18:57 [tpu_model_runner.py:896]   -- num_tokens: 64
INFO 03-18 17:19:08 [tpu_model_runner.py:896]   -- num_tokens: 128
INFO 03-18 17:19:08 [tpu_model_runner.py:903] Compilation finished in in 60.14 [secs].
INFO 03-18 17:19:08 [core.py:138] init engine (profile, create kv cache, warmup model) took 85.70 seconds

Client:

python examples/online_serving/openai_chat_completion_client_for_multimodal.py
Chat completion output from base64 encoded image: The image features a wooden pathway or boardwalk that leads through a lush green field. The pathway is surrounded by tall grass, creating a serene [...]

Benchmark:

vllm serve llava-hf/llava-1.5-7b-hf \
  --max-model-len 4096 --max-num-seqs 32 \
  --port 8004 \
  --disable-log-requests \
  --max-num-batched-tokens 1024 --disable_chunked_mm_input
============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  43.74     
Total input tokens:                      1388      
Total generated tokens:                  7750      
Request throughput (req/s):              2.29      
Output token throughput (tok/s):         177.18    
Total Token throughput (tok/s):          208.91    
---------------Time to First Token----------------
Mean TTFT (ms):                          18875.39  
Median TTFT (ms):                        18757.64  
P99 TTFT (ms):                           38747.09  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          152.70    
Median TPOT (ms):                        165.25    
P99 TPOT (ms):                           192.38    
---------------Inter-token Latency----------------
Mean ITL (ms):                           149.20    
Median ITL (ms):                         77.75     
P99 ITL (ms):                            312.47    
==================================================

Update: The issues listed below have all been addressed by separate PRs. In particular, we work around issue 1 by only processing whole mm_items, as the gather_mm_placeholders operation is intrinsically dynamic.


I do want to discuss a few issues related to this PR where code manipulating on-device tensors will potentially cause re-compilation (here as well as on main):

In _gather_encoder_outputs:

end_idx = min(
    num_computed_tokens - start_pos + num_scheduled_tokens,
    num_encoder_tokens)
assert start_idx < end_idx
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
encoder_outputs.append(encoder_output[start_idx:end_idx])

The last line slices an on-device tensor into a varying shape. Padding here is non-obvious because image features have to be aligned with the image placeholders in input_ids.
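
To see why this matters in isolation (standalone sketch, not the vLLM code): under torch_xla's lazy tracing, every distinct slice length produces a differently shaped tensor and hence a freshly compiled graph.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
encoder_output = torch.randn(576, 4096, device=device)  # e.g. 576 image tokens

for start_idx, end_idx in [(0, 576), (0, 300), (100, 250)]:
    # Each distinct (end_idx - start_idx) changes the output shape, so the
    # traced graph differs and XLA compiles a new program.
    chunk = encoder_output[start_idx:end_idx]
    print(chunk.shape, chunk.sum().item())  # .item() forces execution

print(met.metric_data("CompileTime"))  # compile count grows with each new length

This is why the eventual workaround (see the update above) only gathers whole mm_items, so slice boundaries always coincide with the fixed encoder output shape.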


# Extract the patch tokens
patch_embeddings = json_map_leaves(
    lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
    cast(JSONTree[torch.Tensor], multimodal_embeddings),
)

This dynamically filters the on-device patches, producing a data-dependent shape. The code lives inside the model definition and is duplicated across different models.
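
The same issue in isolation: boolean-mask indexing like x[~x.isnan()] has a data-dependent output shape, which defeats XLA's shape caching. One shape-static alternative, sketched here under the assumption that padding rows are fully NaN (not what the models currently do), is to zero out the padding instead of dropping it and let the downstream scatter ignore those rows:

import torch

def gather_patches_dynamic(x: torch.Tensor) -> torch.Tensor:
    # Output length depends on how much of x is NaN padding -> dynamic shape.
    return x[~x.isnan()].view(-1, *x.shape[1:])

def mask_patches_static(x: torch.Tensor) -> torch.Tensor:
    # Keep the padded shape fixed and zero out padding rows; the caller then
    # scatters only the rows that correspond to real mm placeholders.
    is_pad = x.isnan().all(dim=-1, keepdim=True)
    return torch.where(is_pad, torch.zeros_like(x), x)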


Finally, this is not a real re-compilation issue, but I just wanted to point out that the CLIP-based model I've tested will have its graph broken at

out = F.scaled_dot_product_attention(query,
                                     key,
                                     value,
                                     scale=self.scale)

as the XLA debug logs will reveal an access to an on-device tensor, forcing compilation to start.

Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis:   most likely user code trying to access tensor value before mark_step
Compilation Analysis: Graph Info: 
Compilation Analysis:   Graph Hash: dc89cd1bf4aeecbe31724b74e7832de9
Compilation Analysis:   Number of Graph Inputs: 12
Compilation Analysis:   Number of Graph Outputs: 1
Compilation Analysis: Python Frame Triggered Execution: 
Compilation Analysis:   forward (/home/nick/vllm/vllm/attention/layer.py:313)
Compilation Analysis:   _call_impl (/home/nick/vllm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1762)
Compilation Analysis:   _wrapped_call_impl (/home/nick/vllm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1751)
Compilation Analysis:   forward (/home/nick/vllm/vllm/model_executor/models/clip.py:146)
Compilation Analysis:   _call_impl (/home/nick/vllm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1762)
Compilation Analysis:   _wrapped_call_impl (/home/nick/vllm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1751)
Compilation Analysis:   forward (/home/nick/vllm/vllm/model_executor/models/clip.py:209)
Compilation Analysis:   _call_impl (/home/nick/vllm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1762)
Compilation Analysis:   ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Mar 18, 2025
@NickLucche
Collaborator Author

PS: I want to add tests to this PR; I'm just having some issues getting a test with RemoteOpenAIServer to run.

Collaborator

@alexm-redhat alexm-redhat left a comment


@NickLucche thanks for doing this! In general, LGTM, left some questions/comments.

Collaborator


Doesn't position_ids require this zeroing out too?

Collaborator Author


Great observation, thanks!
This is definitely happening; I am just unsure whether its being added to a padding token renders its position meaningless @robertgshaw2-redhat

Values at runtime
position_ids = [512, 513, 514, 515, . . . , 594, 595, padding starts=>468, 469, 470, ... , 511]

@NickLucche NickLucche force-pushed the tpu-multimodal-encoder-compile2 branch from 4c3ff9d to 50e76ad Compare March 21, 2025 15:32
Member

@mgoin mgoin left a comment


@NickLucche do you want to revive and test this now that the Pallas MHA backend has landed? It should be sufficient to show a benchmark comparing llava with a single image before and after this PR

@mergify

mergify bot commented Mar 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 26, 2025
@NickLucche
Collaborator Author

@mgoin Totally, I had this lined up for right after the main branch is stable. The main issue is working around the two causes of recompilation I listed above, now that we don't allow any in the CI anymore.

@mergify mergify bot added the tpu Related to Google TPUs label Mar 27, 2025
@NickLucche NickLucche force-pushed the tpu-multimodal-encoder-compile2 branch from 50e76ad to 285f084 Compare March 31, 2025 17:13
@mergify mergify bot removed the needs-rebase label Mar 31, 2025
@mergify

mergify bot commented Mar 31, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 31, 2025
@NickLucche NickLucche force-pushed the tpu-multimodal-encoder-compile2 branch from e5e6cd5 to c85c6f9 Compare March 31, 2025 17:46
@mergify mergify bot removed the needs-rebase label Mar 31, 2025
@NickLucche NickLucche force-pushed the tpu-multimodal-encoder-compile2 branch from 037d7d4 to 5ed00e5 Compare April 8, 2025 16:27
@NickLucche
Collaborator Author

NickLucche commented Apr 9, 2025

There's still some recompilation on smaller graphs to address, but this PR should now be ready (thanks @mgoin and @DarkLight1337 for the work on mm scatter-gather and scheduler).

Test with

VLLM_XLA_CACHE_PATH=  VLLM_USE_V1=1 vllm serve llava-hf/llava-1.5-7b-hf --max-model-len 2512 --max-num-seqs 8 --max-num-batched-tokens 1024 --chat-template examples/template_llava.jinja --disable_chunked_mm_input

@NickLucche NickLucche force-pushed the tpu-multimodal-encoder-compile2 branch from 16eefea to c4f4280 Compare April 10, 2025 14:29
@NickLucche
Collaborator Author

VLLM_XLA_CACHE_PATH= VLLM_CHECK_RECOMPILATION=1 VLLM_USE_V1=1 python -m pytest -s tests/v1/tpu/test_multimodal.py is working.
Let's get this merged.

@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 10, 2025
@mergify mergify bot added the ci/build label Apr 10, 2025
@lsy323
Collaborator

lsy323 commented Apr 16, 2025

Hi @NickLucche, I debugged the recompilation issue on my end; here are the findings and the fix. You can check this commit, which includes the fix.

  1. There is a graph break after the mm embedding is computed/gathered from the cache during precompile; however, this graph break is missing in execute_model. I added it here.

  2. self.model.get_multimodal_embeddings returns a torch.Tensor or a list/tuple of torch.Tensor (ref). Therefore the logic in the current code is incorrect; it needs to be extended to handle the torch.Tensor case as well. I added the handling here.

I think we have the recompilation issue, and it's not straightforward to debug, because the code paths for precompile and execution are not the same. We could abstract the device operations into a single function and call that same function during both precompile and execute_model, so that our compilation workflow is more robust and easier to debug.
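
A minimal sketch of that idea (the names here are hypothetical, not the actual model-runner API): both the precompile loop and execute_model call the same device-side helper, so they trace identical graphs.

import torch
import torch_xla.core.xla_model as xm

def run_mm_encoder(model, pixel_values: torch.Tensor) -> torch.Tensor:
    """Single source of truth for the device-side encoder ops."""
    embeddings = model.get_multimodal_embeddings(pixel_values=pixel_values)
    # get_multimodal_embeddings may return a Tensor or a list/tuple of Tensors.
    if isinstance(embeddings, (list, tuple)):
        embeddings = torch.cat(list(embeddings), dim=0)
    # Same graph break in both the precompile and execution paths.
    xm.mark_step()
    return embeddings

# Precompile: run_mm_encoder(model, dummy_pixel_values) once per bucket size.
# Execution:  run_mm_encoder(model, real_pixel_values) hits the cached graphs.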

To debug this, I added more recompilation-check code throughout the execute_model function and ran the workload with the following debug flags: XLA_FLAGS='--xla_dump_to=/tmp/vllm_hlo' PT_XLA_DEBUG_LEVEL=2 XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 VLLM_XLA_CHECK_RECOMPILATION=1 VLLM_USE_V1=1. Combining the IR dumps and the debug prints, we can locate where the recompilation happens. Just wanted to share my experience in case it's helpful.

@NickLucche
Collaborator Author

Thanks for the fix and reporting the debugging method too!

@NickLucche NickLucche force-pushed the tpu-multimodal-encoder-compile2 branch from 47845e0 to a442a8b Compare April 17, 2025 13:24
@NickLucche NickLucche requested a review from mgoin April 17, 2025 14:44
@NickLucche NickLucche force-pushed the tpu-multimodal-encoder-compile2 branch from a442a8b to 4c13715 Compare April 18, 2025 09:34
Member

@mgoin mgoin left a comment


@mgoin mgoin merged commit 2102075 into vllm-project:main Apr 22, 2025
42 checks passed
frieda-huang pushed a commit to frieda-huang/vllm that referenced this pull request Apr 23, 2025
…roject#15051)

Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Siyuan Liu <[email protected]>
Signed-off-by: Frieda (Jingying) Huang <[email protected]>
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
…roject#15051)

Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Siyuan Liu <[email protected]>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
…roject#15051)

Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Siyuan Liu <[email protected]>
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
…roject#15051)

Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Siyuan Liu <[email protected]>
Signed-off-by: Agata Dobrzyniewicz <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…roject#15051)

Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Siyuan Liu <[email protected]>
Signed-off-by: Mu Huai <[email protected]>