-
-
Notifications
You must be signed in to change notification settings - Fork 11k
[TPU][V1][Bugfix] Fix w8a8 recompiilation with GSM8K #15714
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -80,6 +80,7 @@ def __init__( | |
| self.enforce_eager = model_config.enforce_eager | ||
| self.pin_memory = is_pin_memory_available() | ||
| self.dtype = self.model_config.dtype | ||
| self._hidden_states_dtype = self.dtype | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is the same as ref: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L131-L133 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am afraid it's not as I did try running with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the model output generates a different dtype than the one specified in the config, then we have a bug in our model code. Supposely |
||
|
|
||
| self.is_multimodal_model = model_config.is_multimodal_model | ||
| self.sliding_window = model_config.get_sliding_window() | ||
|
|
@@ -758,10 +759,11 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None: | |
| torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) | ||
|
|
||
| with set_forward_context(attn_metadata, self.vllm_config, 0): | ||
| self.model(input_ids=input_ids, | ||
| positions=position_ids, | ||
| kv_caches=kv_caches, | ||
| inputs_embeds=inputs_embeds) | ||
| out = self.model(input_ids=input_ids, | ||
| positions=position_ids, | ||
| kv_caches=kv_caches, | ||
| inputs_embeds=inputs_embeds) | ||
| self._hidden_states_dtype = out.dtype | ||
|
|
||
| def capture_model(self) -> None: | ||
| """Compile the model.""" | ||
|
|
@@ -787,7 +789,7 @@ def capture_model(self) -> None: | |
| num_reqs_to_sample = MIN_NUM_SEQS | ||
| dummy_hidden = torch.randn((num_tokens, hsize), | ||
| device=device, | ||
| dtype=torch.bfloat16) | ||
| dtype=self._hidden_states_dtype) | ||
| # Compile for [8, 16, .., 128,.., `self.max_num_reqs`] | ||
| while True: | ||
| indices = torch.zeros( | ||
|
|
@@ -810,7 +812,7 @@ def capture_model(self) -> None: | |
| num_reqs_to_sample + 1, self.max_num_reqs) | ||
| xm.wait_device_ops() | ||
| end = time.perf_counter() | ||
| logger.info("Compilation finished in in %.2f [secs].", end - start) | ||
| logger.info("Compilation finished in %.2f [secs].", end - start) | ||
| # Record the number cached XLA graph after warming up, this will be | ||
| # used for checking there is no additional graph compilation during | ||
| # runtime execution. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: To be precise,
quantized_matmulis not always generatingfp32output, for some models/weights we generatefp32output because scaler isfp32, therefore the output from quantized matmul is promoted tofp32. We cast it back tox.dtypeto ensure the activation dtype remain the same throughout the modelThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
absolutely!