
Conversation

@NickLucche (Collaborator) commented Mar 28, 2025

This PR fixes the recompilation issue we ran into when running `VLLM_XLA_CHECK_RECOMPILATION=1 VLLM_USE_V1=1 python -m pytest -s -v tests/tpu/test_quantization_accuracy.py` with the w8a8 Llama model.
The actual output dtype may vary depending on the quantized matmul we use; in this case, the hidden states are fp32.
Hence, we pre-compile with dummy fp32 hidden states rather than assuming bf16.
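
As an illustration, a minimal sketch of the idea (the class and method names are simplified stand-ins, not the actual vLLM TPU model runner code): record the dtype that actually comes out of the model forward pass and reuse it when building the dummy hidden states for pre-compilation.

```python
import torch

class TPUModelRunnerSketch:
    """Simplified stand-in for the TPU model runner; not the real class."""

    def __init__(self, model, config_dtype: torch.dtype):
        self.model = model
        # Start from the configured activation dtype (usually bf16)...
        self._hidden_states_dtype = config_dtype

    def execute_model(self, input_ids, positions):
        hidden_states = self.model(input_ids, positions)
        # ...but track what the model actually produces: a quantized matmul
        # may promote activations to fp32.
        self._hidden_states_dtype = hidden_states.dtype
        return hidden_states

    def precompile_logits_and_sampling(self, num_tokens, hsize, device):
        # Dummy hidden states must match the real dtype, otherwise XLA traces
        # a different graph at runtime and recompiles.
        return torch.randn((num_tokens, hsize),
                           device=device,
                           dtype=self._hidden_states_dtype)
```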

Update:

To provide more context: the "culprit" of the change is `torch.ops.xla.quantized_matmul`, which returns fp32. Casting it back down to bf16 should make better use of the TPU in the intermediate layers. Benchmark results:

pre
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  171.10    
Total input tokens:                      1720000   
Total generated tokens:                  120874    
Request throughput (req/s):              5.84      
Output token throughput (tok/s):         706.46    
Total Token throughput (tok/s):          10759.14  
---------------Time to First Token----------------
Mean TTFT (ms):                          84022.40  
Median TTFT (ms):                        83897.42  
P99 TTFT (ms):                           166727.15 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          44.75     
Median TPOT (ms):                        45.22     
P99 TPOT (ms):                           45.90     
---------------Inter-token Latency----------------
Mean ITL (ms):                           44.73     
Median ITL (ms):                         45.18     
P99 ITL (ms):                            47.79     
==================================================

post 
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  170.58    
Total input tokens:                      1720000   
Total generated tokens:                  120508    
Request throughput (req/s):              5.86      
Output token throughput (tok/s):         706.46    
Total Token throughput (tok/s):          10789.74  
---------------Time to First Token----------------
Mean TTFT (ms):                          83928.69  
Median TTFT (ms):                        83772.92  
P99 TTFT (ms):                           166358.23 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          44.66     
Median TPOT (ms):                        45.12     
P99 TPOT (ms):                           45.76     
---------------Inter-token Latency----------------
Mean ITL (ms):                           44.67     
Median ITL (ms):                         45.10     
P99 ITL (ms):                            46.79     
==================================================

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build, v1, and tpu (Related to Google TPUs) labels Mar 28, 2025
@alexm-redhat (Collaborator) left a comment

Good catch!

@alexm-redhat alexm-redhat enabled auto-merge (squash) March 28, 2025 17:11
@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Mar 28, 2025
@mgoin mgoin added the bug (Something isn't working) label Mar 28, 2025
@NickLucche (Collaborator, Author)

let's hold for benchmarks

@yaochengji yaochengji self-assigned this Mar 28, 2025
@yaochengji (Collaborator) left a comment

Thanks for fixing this!

I think the root cause of the recompilation is that we don't use the same code in `execute_model` and `dummy_run`. To prevent similar bugs in the future, could we use a common utility function that wraps all the TPU computations (encoder, `self.model`, logits processor / sampler) and make sure it is shared by both?

cc @robertgshaw2-redhat: the sampler-disabling logic would then not need to be applied in two places in your PR, #15662.
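
A rough sketch of what such a shared helper could look like (the function name and signatures here are hypothetical, not vLLM's actual API):

```python
import torch

def run_device_graph(model, sampler, input_ids, positions, logits_indices):
    """Single entry point for all on-device work: backbone, logits, sampling.

    If both execute_model and dummy_run go through this function, the graphs
    traced during warm-up are guaranteed to match the graphs traced at
    runtime, so shape/dtype mismatches like the one fixed here cannot slip
    through.
    """
    hidden_states = model(input_ids, positions)
    logits = model.compute_logits(hidden_states[logits_indices])
    return sampler(logits)
```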

     dummy_hidden = torch.randn((num_tokens, hsize),
                                device=device,
-                               dtype=torch.bfloat16)
+                               dtype=self._output_dtype)
Collaborator:

Can we directly use the output of the `self.model` call? Then we wouldn't need such a dtype argument.

Collaborator (Author):

Either we pay for an extra forward pass, or the whole loop moves into `dummy_run`.

Collaborator (Author):

I was just respecting the function signature, so nothing is returned directly from `dummy_run`.

@NickLucche (Collaborator, Author)

> use the same code in execute_model and dummy_run

I am not sure: while I'd definitely like a cleaner structure, to match what we do in the GPU model runner, `dummy_run` only runs the decoder in the multimodal case.

Also, we're already underestimating memory, as @robertgshaw2-redhat found out in his PR.

auto-merge was automatically disabled March 28, 2025 18:41

Head branch was pushed to by a user without write access

@NickLucche NickLucche requested a review from tlrmchlsmth as a code owner March 28, 2025 18:41
@mgoin (Member) commented Mar 28, 2025

The GPU model runner has separate `get_multimodal_embeddings`, `_dummy_run`, and `_dummy_sampler_run` calls for profiling, so I agree with Nicolo about staying true to the GPU flow for now. In particular, I don't see how it makes sense to lump the multimodal encoder and decoder together.

@yaochengji (Collaborator)

> The gpu model runner has separate get_multimodal_embeddings, _dummy_run, and _dummy_sampler_run called for profiling

The main difference is that the TPU has to pre-compile all the computations on device, not only the transformer backbone. But I'm fine with the current situation as long as we have the recompilation check in our CI tests.
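
For reference, a hedged sketch of how such a recompilation check can be expressed with torch_xla's metrics module (the actual VLLM_XLA_CHECK_RECOMPILATION implementation may differ):

```python
import torch_xla.debug.metrics as met

def compile_count() -> int:
    # metric_data returns (count, total, samples) for a metric, or None if
    # the metric was never recorded.
    data = met.metric_data("CompileTime")
    return data[0] if data is not None else 0

# Snapshot the count once all dummy_run warm-up passes have finished...
baseline = compile_count()

def assert_no_recompilation() -> None:
    # ...then assert after each real step that no new graph was compiled.
    assert compile_count() == baseline, (
        "XLA recompiled at runtime: a warm-up input (shape or dtype) does "
        "not match what execute_model actually feeds the model")
```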

@yaochengji (Collaborator) left a comment

LGTM, thanks for fixing this.

     self.enforce_eager = model_config.enforce_eager
     self.pin_memory = is_pin_memory_available()
     self.dtype = self.model_config.dtype
+    self._hidden_states_dtype = self.dtype
Collaborator:

I think this is the same: `model_config.dtype` specifies the activation dtype, which is the same dtype as the hidden states dtype here.

ref: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L131-L133

Collaborator (Author):

I'm afraid it's not: I did try running with `self.dtype` before this PR, but it just evaluates to bf16, while `quantized_matmul`, as you noted below, does an implicit cast to fp32 that is not accounted for in `model_config.dtype`.
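
As a plain-PyTorch illustration of that promotion (CPU tensors here, not the TPU op itself): multiplying bf16 activations by an fp32 per-channel scale yields fp32, which is how the hidden states end up diverging from `model_config.dtype`.

```python
import torch

x = torch.randn(2, 4, dtype=torch.bfloat16)           # activations in bf16
scale = torch.full((4,), 0.05, dtype=torch.float32)   # per-channel scale in fp32
print((x * scale).dtype)  # torch.float32 -- promoted, no longer bf16
```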

Collaborator:

If the model generates output in a different dtype than the one specified in the config, then we have a bug in our model code. Supposedly `_hidden_states_dtype` and `dtype` should always be the same.

                            int4_weight=False,
                            quantize_activation=True)

+    # `quantized_matmul` output is fp32, cast it down to bf16 for perf
Collaborator:

nit: To be precise, `quantized_matmul` does not always generate fp32 output; for some models/weights we get fp32 output because the scaler is fp32, so the output of the quantized matmul is promoted to fp32. We cast it back to `x.dtype` to ensure the activation dtype remains the same throughout the model.

Collaborator (Author):

absolutely!
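
A short sketch of the cast-back pattern discussed above (the `quantized_matmul` arguments are copied from the snippet in this thread and may not be the complete signature):

```python
import torch

def w8a8_linear(x: torch.Tensor, w_q: torch.Tensor, scaler: torch.Tensor) -> torch.Tensor:
    # The quantized matmul output may be promoted to fp32 when the scaler
    # is fp32 (see the discussion above).
    out = torch.ops.xla.quantized_matmul(x, w_q, scaler,
                                         int4_weight=False,
                                         quantize_activation=True)
    # Cast back to the activation dtype so downstream layers keep running
    # in bf16 and the hidden-states dtype stays consistent.
    return out.to(x.dtype)
```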

@vllm-bot vllm-bot merged commit da461f3 into vllm-project:main Mar 29, 2025
36 of 38 checks passed