65 changes: 65 additions & 0 deletions tests/platforms/test_tpu.py
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for the Pallas MOE implementation.

Run `pytest tests/platforms/test_tpu.py`.
"""
from unittest.mock import MagicMock, patch

import pytest

import vllm.config
from vllm.platforms.tpu import TpuPlatform


@pytest.mark.parametrize(
"use_v1,initial_block_size,expected_block_size",
[
(True, 32, 32), # Case 1: v1: block_size set, should remain unchanged
(
True, None, 128
), # Case 2: v1: block_size None, should be set to get_page_size (128)
(False, None, 16), # Case 3: v0: block_size None, should be set to 16
(False, 32, 32), # Case 4: v0: block_size set, should remain unchanged
])
@patch(
"vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_page_size",
return_value=128)
@patch(
"vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_min_page_size",
return_value=8)
def test_tpu_platform_update_vllm_config_block_size_respect_passin_block_size(
mock_get_min_page_size, mock_get_page_size, use_v1, initial_block_size,
expected_block_size) -> None:
"""Test TPU platform updates VLLM config with block size."""
# arrange
mock_cached_config = MagicMock()
mock_cached_config.block_size = initial_block_size

mock_model_config = MagicMock()
mock_model_config.dtype = "float16"

mock_vllm_config = MagicMock()
mock_vllm_config.cache_config = mock_cached_config
mock_vllm_config.compilation_config = MagicMock()
mock_vllm_config.compilation_config.level = (
vllm.config.CompilationLevel.DYNAMO_ONCE)
mock_vllm_config.compilation_config.backend = "openxla"
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.speculative_config = None
mock_vllm_config.parallel_config = MagicMock()
mock_vllm_config.parallel_config.worker_cls = (
"vllm.v1.worker.tpu_worker.TPUWorker")
mock_vllm_config.scheduler_config = MagicMock()

# act
with patch("vllm.envs.VLLM_USE_V1", use_v1):
TpuPlatform.check_and_update_config(mock_vllm_config)

# assert
assert mock_cached_config.block_size == expected_block_size
if use_v1:
mock_get_min_page_size.assert_called()
if initial_block_size is None:
mock_get_page_size.assert_called()
else:
mock_get_page_size.assert_not_called()
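
The four parametrized cases above reduce to a simple precedence rule. The sketch below is a paraphrase of the behavior under test, not the actual vLLM code path; resolve_block_size is a hypothetical helper, and the page-size values simply mirror the mocked return values used in the test.

# Hedged paraphrase of the block-size resolution exercised by the test above.
# Not the real vLLM implementation; page_size/min_page_size mirror the mocked
# return values of get_page_size (128) and get_min_page_size (8).
def resolve_block_size(use_v1, user_block_size, page_size=128, min_page_size=8):
    if use_v1:
        # v1: keep a user-provided value, otherwise fall back to the backend's
        # page size, then raise it to the minimum page size if it is too small.
        block_size = user_block_size or page_size
        return max(block_size, min_page_size)
    # v0: default to 16 when no block size was given.
    return user_block_size or 16

assert resolve_block_size(True, 32) == 32     # Case 1
assert resolve_block_size(True, None) == 128  # Case 2
assert resolve_block_size(False, None) == 16  # Case 3
assert resolve_block_size(False, 32) == 32    # Case 4
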
20 changes: 14 additions & 6 deletions vllm/platforms/tpu.py
@@ -93,9 +93,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.config import CompilationLevel

cache_config = vllm_config.cache_config
# For v0, the default block size is 16.
if cache_config and cache_config.block_size is None:
cache_config.block_size = cast(BlockSize, 16)
assert cache_config is not None

compilation_config = vllm_config.compilation_config

# TPU only supports DYNAMO_ONCE compilation level
@@ -118,8 +117,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if envs.VLLM_USE_V1:
from vllm.v1.attention.backends.pallas import (
PallasAttentionBackend)
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config) # type: ignore[assignment]

# For v1, the default block size is calculated from vllm_config.
cache_config.block_size = (
    cache_config.block_size
    or PallasAttentionBackend.get_page_size(
        vllm_config)  # type: ignore[assignment]
)

Collaborator:
Should we add a warning when the block_size set by the user is not optimal? cc @bythew3i

Contributor Author:
If we know what value is optimal for the user, I think we should add one.

The reason I want to put back the block-size parameter is that I found the value from PallasAttentionBackend.get_page_size is not a good value for me when my max-model-len is large.

In my benchmark testing, my max-model-len is 4096, which results in 256 from get_page_size, but block_size=16 gives me a better result.

min_page_size = PallasAttentionBackend.get_min_page_size(
vllm_config)
if min_page_size > cache_config.block_size:
@@ -130,7 +135,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
min_page_size,
)
cache_config.block_size = min_page_size # type: ignore[assignment]

else:
# For v0, the default block size is 16.
cache_config.block_size = (cache_config.block_size
or cast(BlockSize, 16))
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
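
As a footnote to the review discussion above, here is a minimal sketch of keeping an explicit block-size override once this change lands. It assumes the standard vllm.LLM entry point; the model name and values are illustrative only and are not part of this PR.

# Hedged sketch: with this change, a user-supplied block_size is respected on
# TPU v1 instead of being overwritten by PallasAttentionBackend.get_page_size().
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
    max_model_len=4096,
    block_size=16,  # explicit override; omit to derive it from get_page_size()
)

This mirrors the author's benchmark note: with max-model-len 4096, get_page_size would otherwise derive 256, while an explicit block_size=16 performed better for them.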