[TPU][V1][Bugfix] Fix w8a8 recompilation with GSM8K #15714
Conversation
Signed-off-by: NickLucche <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
alexm-redhat left a comment:
Good catch!
Let's hold for benchmarks.
yaochengji left a comment:
Thanks for fixing this!
I think the root cause of the recompilation is that we don't use the same code in execute_model and dummy_run. To prevent further similar bugs, can we use a common utility function that wraps all the TPU computations (encoder, self.model, logits processor / sampler) and make sure it is shared by both?
cc @robertgshaw2-redhat: then the sampler-disabling logic wouldn't need to be applied in two places in your PR, #15662.
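For illustration, a shared helper along those lines might look roughly like this (a minimal sketch with hypothetical names, not the actual vLLM API):

```python
import torch

# Hypothetical sketch of the suggestion above: a single helper that wraps all
# on-device TPU computation (backbone forward, logits projection, sampling),
# called from both execute_model and dummy_run so the two paths trace the
# exact same graph and cannot drift apart and trigger recompilation.
def run_tpu_step(model, sampler, input_ids: torch.Tensor,
                 positions: torch.Tensor, logits_indices: torch.Tensor,
                 sampling_metadata):
    hidden_states = model(input_ids=input_ids, positions=positions)
    # Select only the positions we sample from, with the same shapes/dtypes
    # as the real execution path, so the compiled graph is reused verbatim.
    logits = model.compute_logits(hidden_states[logits_indices], None)
    return sampler(logits, sampling_metadata)
```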
vllm/v1/worker/tpu_model_runner.py (Outdated)
  dummy_hidden = torch.randn((num_tokens, hsize),
                             device=device,
-                            dtype=torch.bfloat16)
+                            dtype=self._output_dtype)
Can we directly use the output of the self.model call? Then we wouldn't need such a dtype argument.
Either we pay an extra forward, or the whole loop moves into dummy_run.
I was just respecting the function signature, so nothing is returned directly from dummy_run.
I am not sure. While I'd definitely like a cleaner structure, to match what we do in the GPU model runner, dummy_run only runs the decoder in the multimodal case. Also, we're already underestimating memory, as @robertgshaw2-redhat found out in his PR.
Signed-off-by: NickLucche <[email protected]>
Head branch was pushed to by a user without write access
The GPU model runner has separate get_multimodal_embeddings, _dummy_run, and _dummy_sampler_run calls for profiling, so I agree with Nicolo about staying true to the GPU runner for now. I particularly don't see how it makes sense to lump the mm encoder and decoder together.
The main difference is that TPU has to pre-compile all the computations on device, not only the transformer backbone. But I'm fine with the current situation as long as we have the recompilation check in our CI test.
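For reference, a recompilation guard of the kind mentioned above could be sketched like this (assuming torch_xla's debug-metrics API; the exact integration in the CI test may differ):

```python
import torch_xla.debug.metrics as met

# Count XLA compilations after warm-up; if running the workload later makes
# this number grow, some shape or dtype was not covered by pre-compilation.
data = met.metric_data("CompileTime")
compiles_before = data[0] if data is not None else 0

# ... run the workload here (e.g. the GSM8K accuracy test) ...

data = met.metric_data("CompileTime")
compiles_after = data[0] if data is not None else 0
assert compiles_after == compiles_before, "unexpected XLA recompilation detected"
```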
yaochengji left a comment:
LGTM, thanks for fixing this.
  self.enforce_eager = model_config.enforce_eager
  self.pin_memory = is_pin_memory_available()
  self.dtype = self.model_config.dtype
+ self._hidden_states_dtype = self.dtype
I think this is the same: model_config.dtype specifies the activation dtype, which is the same as the hidden-states dtype here.
ref: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L131-L133
I'm afraid it's not: I did try running with self.dtype before this PR, but it just evaluates to bf16, while quantized_matmul, as you noted below, does an implicit cast to fp32 that is not accounted for in model_config.dtype.
If the model generates output with a different dtype than the one specified in the config, then we have a bug in our model code. Supposedly, _hidden_states_dtype and dtype should always be the same.
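To make the trade-off concrete, the pattern under discussion is roughly the following (a sketch with simplified, partly hypothetical names):

```python
import torch

class TPURunnerSketch:
    def __init__(self, model_dtype: torch.dtype):
        self.dtype = model_dtype
        # Start from the configured activation dtype; may be refined once we
        # see what the model actually produces (e.g. fp32 after a quantized
        # matmul whose scale is fp32).
        self._hidden_states_dtype = model_dtype

    def record_hidden_states(self, hidden_states: torch.Tensor) -> None:
        # Called from the real execution path so dummy_run stays in sync.
        self._hidden_states_dtype = hidden_states.dtype

    def dummy_hidden(self, num_tokens: int, hsize: int, device) -> torch.Tensor:
        # Pre-compile logits/sampling with the dtype the model really emits,
        # instead of hard-coding bf16.
        return torch.randn((num_tokens, hsize), device=device,
                           dtype=self._hidden_states_dtype)
```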
      int4_weight=False,
      quantize_activation=True)

+ # `quantized_matmul` output is fp32, cast it down to bf16 for perf
Nit: to be precise, quantized_matmul does not always generate fp32 output. For some models/weights we get fp32 output because the scaler is fp32, so the output of the quantized matmul is promoted to fp32. We cast it back to x.dtype to ensure the activation dtype remains the same throughout the model.
absolutely!
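A small plain-PyTorch illustration of the promotion described above (not the TPU kernel itself; shapes and names are made up):

```python
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)               # bf16 activations
w = torch.randint(-128, 127, (8, 16), dtype=torch.int8)   # int8 weights
scale = torch.rand(16, dtype=torch.float32)               # fp32 per-channel scale

# Dequantized matmul in the activation dtype...
out = torch.matmul(x, w.to(x.dtype))
assert out.dtype == torch.bfloat16

# ...but applying the fp32 scale promotes the result to fp32,
out = out * scale
assert out.dtype == torch.float32

# so the output is cast back to x.dtype to keep activations consistent downstream.
out = out.to(x.dtype)
assert out.dtype == torch.bfloat16
```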
This PR fixes the compilation issue we run into with
VLLM_XLA_CHECK_RECOMPILATION=1 VLLM_USE_V1=1 python -m pytest -s -v tests/tpu/test_quantization_accuracy.py
with the w8a8 Llama model. The actual output dtype may vary depending on the quantized matmul we use; in this case, the hidden states are fp32. Hence, we pre-compile with dummy fp32 hidden states rather than assuming bf16.
Update:
To provide more context here: the "culprit" of the change is torch.ops.xla.quantized_matmul, which returns fp32. Casting it down to bf16 should leverage TPU capabilities better in intermediate layers. Results: