10 changes: 4 additions & 6 deletions .buildkite/run-tpu-v1-test.sh
@@ -28,16 +28,14 @@ docker run --privileged --net host --shm-size=16G -it \
&& echo TEST_3 \
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
&& echo TEST_4 \
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
&& echo TEST_5 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
&& echo TEST_6 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
&& echo TEST_7 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \


# TODO: This test fails because it uses RANDOM_SEED sampling
# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \

# TODO: Re-enable this after fixing recompilation in quantization.
# && echo TEST_4 \
# && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
@@ -97,7 +97,8 @@ def apply_weights(self,
block_size=-1,
int4_weight=False,
quantize_activation=True)

# `quantized_matmul` output is fp32, cast it down to bf16 for perf
Collaborator:
nit: To be precise, quantized_matmul does not always generate fp32 output. For some models/weights the output is fp32 because the scale is fp32, so the result of the quantized matmul is promoted to fp32. We cast it back to x.dtype to ensure the activation dtype remains the same throughout the model.

Collaborator (Author):
absolutely!

out = out.to(x.dtype)
# Explicitly capture control flow to make dynamo happy.
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
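
A minimal, self-contained sketch of the dtype behaviour discussed in the thread above; `quantized_matmul_sketch` is a hypothetical stand-in, not vLLM's kernel, and only illustrates why an fp32 scale promotes bf16 activations to fp32 and why the result is cast back to `x.dtype`:

```python
import torch

# Hypothetical stand-in for a quantized matmul: int8 weights dequantized with
# an fp32 per-channel scale. Multiplying by the fp32 scale promotes the result
# to fp32 even when the activations are bf16.
def quantized_matmul_sketch(x: torch.Tensor, w_int8: torch.Tensor,
                            scale: torch.Tensor) -> torch.Tensor:
    return torch.matmul(x.to(torch.float32), w_int8.to(torch.float32)) * scale


x = torch.randn(4, 8, dtype=torch.bfloat16)
w_int8 = torch.randint(-128, 128, (8, 16), dtype=torch.int8)
scale = torch.rand(16, dtype=torch.float32)

out = quantized_matmul_sketch(x, w_int8, scale)
assert out.dtype == torch.float32   # promoted by the fp32 scale

out = out.to(x.dtype)               # cast back, as the diff does above
assert out.dtype == torch.bfloat16  # activation dtype preserved downstream
```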
14 changes: 8 additions & 6 deletions vllm/v1/worker/tpu_model_runner.py
@@ -80,6 +80,7 @@ def __init__(
self.enforce_eager = model_config.enforce_eager
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
self._hidden_states_dtype = self.dtype
Collaborator:
I think this is the same, as model_config.dtype specifies the activation dtype, which is the same dtype as the hidden_states_dtype here.

ref: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L131-L133

Collaborator (Author):
I'm afraid it's not: I did try running with self.dtype before this PR, but it just evaluates to bf16, while quantized_matmul, as you noted below, does an implicit cast to fp32 that is not accounted for in model_config.dtype.

Collaborator:
If the model output has a different dtype than the one specified in the config, then we have a bug in our model code. Supposedly _hidden_states_dtype and dtype should always be the same.


self.is_multimodal_model = model_config.is_multimodal_model
self.sliding_window = model_config.get_sliding_window()
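
A minimal sketch of the approach taken here, with assumed names rather than the actual `TPUModelRunner`: start from the config dtype, record what the model actually emits during the dummy run, and reuse that dtype when precompiling the sampler graphs:

```python
import torch

class RunnerSketch:
    """Illustrative only: tracks the dtype the model actually emits."""

    def __init__(self, config_dtype: torch.dtype):
        self.dtype = config_dtype
        # Until a dummy run tells us otherwise, assume the config dtype.
        self._hidden_states_dtype = config_dtype

    def dummy_run(self, model, input_ids: torch.Tensor) -> None:
        out = model(input_ids)
        # May differ from self.dtype if e.g. quantization promotes to fp32.
        self._hidden_states_dtype = out.dtype

    def dummy_hidden(self, num_tokens: int, hidden_size: int) -> torch.Tensor:
        # Used when precompiling the sampler: the dtype must match the real
        # hidden states or XLA will recompile at runtime.
        return torch.randn(num_tokens, hidden_size,
                           dtype=self._hidden_states_dtype)
```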
@@ -758,10 +759,11 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None:
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)

with set_forward_context(attn_metadata, self.vllm_config, 0):
self.model(input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds)
out = self.model(input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds)
self._hidden_states_dtype = out.dtype

def capture_model(self) -> None:
"""Compile the model."""
@@ -787,7 +789,7 @@ def capture_model(self) -> None:
num_reqs_to_sample = MIN_NUM_SEQS
dummy_hidden = torch.randn((num_tokens, hsize),
device=device,
dtype=torch.bfloat16)
dtype=self._hidden_states_dtype)
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
while True:
indices = torch.zeros(
@@ -810,7 +812,7 @@
num_reqs_to_sample + 1, self.max_num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
logger.info("Compilation finished in %.2f [secs].", end - start)
# Record the number of cached XLA graphs after warming up; this will be
# used to check that there is no additional graph compilation during
# runtime execution.
4 changes: 2 additions & 2 deletions vllm/v1/worker/tpu_worker.py
@@ -105,8 +105,8 @@ def init_device(self):

# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
# 30-40 graphs for decode. 128 is an arbitrary safe number.
# TODO (NickLucche) On gsm we compile 80+ graphs.
# Re-evaluate limit, with MM we may get close to this limit.
torch._dynamo.config.cache_size_limit = 128
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
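
For context on the cache-size comment above, a small standalone sketch (illustrative function and shapes, not the worker's real graphs) of how each distinct input shape consumes one dynamo cache slot:

```python
import torch

# Each distinct input shape that reaches a compiled region produces its own
# graph; once a frame needs more graphs than cache_size_limit allows, dynamo
# stops compiling it and falls back to eager. Keep the limit above the
# expected graph count (80+ on gsm-style runs, per the TODO above).
torch._dynamo.config.cache_size_limit = 128

@torch.compile(dynamic=False)
def square(x: torch.Tensor) -> torch.Tensor:
    return x * x

# Warm-up over bucketed sizes: each new shape compiles one more graph.
for n in (8, 16, 32, 64, 128):
    square(torch.randn(n))
```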