diff --git a/.buildkite/run-tpu-v1-test.sh b/.buildkite/run-tpu-v1-test.sh
index 2c356b8fe527..89252000f400 100755
--- a/.buildkite/run-tpu-v1-test.sh
+++ b/.buildkite/run-tpu-v1-test.sh
@@ -28,16 +28,14 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_3 \
     && pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
     && echo TEST_4 \
-    && python3 /workspace/vllm/examples/offline_inference/tpu.py \
+    && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
     && echo TEST_5 \
-    && pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
+    && python3 /workspace/vllm/examples/offline_inference/tpu.py \
     && echo TEST_6 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
+    && echo TEST_7 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \


 # TODO: This test fails because it uses RANDOM_SEED sampling
 # && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
-
-# TODO: Re-enable this after fixing recompilation in quantization.
-# && echo TEST_4 \
-# && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
index 0bf090d7fab3..089314071d39 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
@@ -97,7 +97,8 @@ def apply_weights(self,
                                              block_size=-1,
                                              int4_weight=False,
                                              quantize_activation=True)
-
+        # `quantized_matmul` output is fp32, cast it down to bf16 for perf
+        out = out.to(x.dtype)
         # Explicitly capture control flow to make dynamo happy.
         # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
         return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 5401fff2bf19..4369d5a14af5 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -80,6 +80,7 @@ def __init__(
         self.enforce_eager = model_config.enforce_eager
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
+        self._hidden_states_dtype = self.dtype

         self.is_multimodal_model = model_config.is_multimodal_model
         self.sliding_window = model_config.get_sliding_window()
@@ -758,10 +759,11 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None:
             torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)

         with set_forward_context(attn_metadata, self.vllm_config, 0):
-            self.model(input_ids=input_ids,
-                       positions=position_ids,
-                       kv_caches=kv_caches,
-                       inputs_embeds=inputs_embeds)
+            out = self.model(input_ids=input_ids,
+                             positions=position_ids,
+                             kv_caches=kv_caches,
+                             inputs_embeds=inputs_embeds)
+            self._hidden_states_dtype = out.dtype

     def capture_model(self) -> None:
         """Compile the model."""
@@ -787,7 +789,7 @@ def capture_model(self) -> None:
         num_reqs_to_sample = MIN_NUM_SEQS
         dummy_hidden = torch.randn((num_tokens, hsize),
                                    device=device,
-                                   dtype=torch.bfloat16)
+                                   dtype=self._hidden_states_dtype)
         # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
         while True:
             indices = torch.zeros(
@@ -810,7 +812,7 @@ def capture_model(self) -> None:
             num_reqs_to_sample + 1, self.max_num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].", end - start)
+        logger.info("Compilation finished in %.2f [secs].", end - start)
         # Record the number cached XLA graph after warming up, this will be
         # used for checking there is no additional graph compilation during
         # runtime execution.
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 4d9a113e39ee..7c81d1365628 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -105,8 +105,8 @@ def init_device(self):

         # Increase the cache size limit, which is the maximum number of
         # dynamo graphs that can be compiled.
-        # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
-        # 30-40 graphs for decode. 128 is an arbitrary safe number.
+        # TODO (NickLucche) On gsm we compile 80+ graphs.
+        # Re-evaluate limit, with MM we may get close to this limit.
         torch._dynamo.config.cache_size_limit = 128
         # Use persistent cache to avoid XLA recompilation.
         # NOTE(woosuk): Set per-rank cache path since different ranks
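For context, the tpu_model_runner.py change follows a probe-and-reuse pattern: run one dummy forward, record the dtype of the hidden states it actually returns, and allocate the sampler warm-up tensors with that same dtype so the pre-compiled XLA graphs match what the real forward pass produces (quantized models now emit bf16 after the cast added in xla.py). Below is a minimal standalone sketch of that idea; `DummyModel`, `Runner`, `dummy_run`, and `precompile_sampler` are illustrative names, not vLLM APIs.

```python
import torch

# Hypothetical stand-in for the real model: a quantized TPU matmul may emit
# hidden states in a dtype other than the configured model dtype.
class DummyModel(torch.nn.Module):
    def __init__(self, hidden_size: int, out_dtype: torch.dtype):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)
        self.out_dtype = out_dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for e.g. a quantized matmul whose output dtype is fixed.
        return self.proj(x).to(self.out_dtype)


class Runner:
    def __init__(self, model: torch.nn.Module, model_dtype: torch.dtype):
        self.model = model
        # Default to the configured dtype until a dummy run says otherwise.
        self._hidden_states_dtype = model_dtype

    def dummy_run(self, num_tokens: int, hidden_size: int) -> None:
        out = self.model(torch.zeros(num_tokens, hidden_size))
        # Record the dtype the forward pass actually produces.
        self._hidden_states_dtype = out.dtype

    def precompile_sampler(self, num_tokens: int,
                           hidden_size: int) -> torch.Tensor:
        # Allocate warm-up hidden states with the observed dtype rather than a
        # hard-coded bfloat16, so pre-compiled graphs match runtime inputs.
        return torch.randn(num_tokens, hidden_size,
                           dtype=self._hidden_states_dtype)


runner = Runner(DummyModel(16, torch.float32), model_dtype=torch.bfloat16)
runner.dummy_run(num_tokens=8, hidden_size=16)
dummy_hidden = runner.precompile_sampler(num_tokens=8, hidden_size=16)
assert dummy_hidden.dtype == torch.float32  # matches what the model emits
```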