Skip to content

Commit da461f3

Browse files
authored
[TPU][V1][Bugfix] Fix w8a8 recompilation with GSM8K (#15714)
Signed-off-by: NickLucche <[email protected]>
1 parent 5b800f0 commit da461f3

File tree

4 files changed

+16
-15
lines changed

4 files changed

+16
-15
lines changed

.buildkite/run-tpu-v1-test.sh

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,14 @@ docker run --privileged --net host --shm-size=16G -it \
2828
&& echo TEST_3 \
2929
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
3030
&& echo TEST_4 \
31-
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
31+
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
3232
&& echo TEST_5 \
33-
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
33+
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
3434
&& echo TEST_6 \
35+
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
36+
&& echo TEST_7 \
3537
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
3638

3739

3840
# TODO: This test fails because it uses RANDOM_SEED sampling
3941
# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
40-
41-
# TODO: Re-enable this after fixing recompilation in quantization.
42-
# && echo TEST_4 \
43-
# && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \

vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ def apply_weights(self,
9797
block_size=-1,
9898
int4_weight=False,
9999
quantize_activation=True)
100-
100+
# `quantized_matmul` output is fp32, cast it down to bf16 for perf
101+
out = out.to(x.dtype)
101102
# Explicitly capture control flow to make dynamo happy.
102103
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
103104
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])

vllm/v1/worker/tpu_model_runner.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __init__(
8080
self.enforce_eager = model_config.enforce_eager
8181
self.pin_memory = is_pin_memory_available()
8282
self.dtype = self.model_config.dtype
83+
self._hidden_states_dtype = self.dtype
8384

8485
self.is_multimodal_model = model_config.is_multimodal_model
8586
self.sliding_window = model_config.get_sliding_window()
@@ -771,10 +772,11 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None:
771772
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
772773

773774
with set_forward_context(attn_metadata, self.vllm_config, 0):
774-
self.model(input_ids=input_ids,
775-
positions=position_ids,
776-
kv_caches=kv_caches,
777-
inputs_embeds=inputs_embeds)
775+
out = self.model(input_ids=input_ids,
776+
positions=position_ids,
777+
kv_caches=kv_caches,
778+
inputs_embeds=inputs_embeds)
779+
self._hidden_states_dtype = out.dtype
778780

779781
def capture_model(self) -> None:
780782
"""Compile the model."""
@@ -800,7 +802,7 @@ def capture_model(self) -> None:
800802
num_reqs_to_sample = MIN_NUM_SEQS
801803
dummy_hidden = torch.randn((num_tokens, hsize),
802804
device=device,
803-
dtype=torch.bfloat16)
805+
dtype=self._hidden_states_dtype)
804806
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
805807
while True:
806808
indices = torch.zeros(
@@ -823,7 +825,7 @@ def capture_model(self) -> None:
823825
num_reqs_to_sample + 1, self.max_num_reqs)
824826
xm.wait_device_ops()
825827
end = time.perf_counter()
826-
logger.info("Compilation finished in in %.2f [secs].", end - start)
828+
logger.info("Compilation finished in %.2f [secs].", end - start)
827829
# Record the number cached XLA graph after warming up, this will be
828830
# used for checking there is no additional graph compilation during
829831
# runtime execution.

vllm/v1/worker/tpu_worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ def init_device(self):
105105

106106
# Increase the cache size limit, which is the maximum number of
107107
# dynamo graphs that can be compiled.
108-
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
109-
# 30-40 graphs for decode. 128 is an arbitrary safe number.
108+
# TODO (NickLucche) On gsm we compile 80+ graphs.
109+
# Re-evaluate limit, with MM we may get close to this limit.
110110
torch._dynamo.config.cache_size_limit = 128
111111
# Use persistent cache to avoid XLA recompilation.
112112
# NOTE(woosuk): Set per-rank cache path since different ranks

0 commit comments

Comments
 (0)