-
-
Notifications
You must be signed in to change notification settings - Fork 10.7k
[TPU] Calculate block size only when not set. #18292
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# SPDX-License-Identifier: Apache-2.0
"""Tests for TPU platform block-size selection in check_and_update_config.

Run `pytest tests/platforms/test_tpu.py`.
"""
from unittest.mock import MagicMock, patch

import pytest

import vllm.config
from vllm.platforms.tpu import TpuPlatform
||
@pytest.mark.parametrize(
    "use_v1,initial_block_size,expected_block_size",
    [
        (True, 32, 32),  # Case 1: v1: block_size set, should remain unchanged
        (
            True, None, 128
        ),  # Case 2: v1: block_size None, should be set to get_page_size (128)
        (False, None, 16),  # Case 3: v0: block_size None, should be set to 16
        (False, 32, 32),  # Case 4: v0: block_size set, should remain unchanged
    ])
@patch(
    "vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_page_size",
    return_value=128)
@patch(
    "vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_min_page_size",
    return_value=8)
def test_tpu_platform_update_vllm_config_block_size_respect_passin_block_size(
        mock_get_min_page_size, mock_get_page_size, use_v1, initial_block_size,
        expected_block_size) -> None:
    """Verify TpuPlatform.check_and_update_config respects a user-set block size.

    A user-provided ``cache_config.block_size`` must survive the config
    update unchanged; only when it is ``None`` should the platform compute
    a default (v1: ``PallasAttentionBackend.get_page_size``, mocked to 128;
    v0: the constant 16).
    """
    # arrange: a cache config carrying the parametrized initial block size.
    mock_cached_config = MagicMock()
    mock_cached_config.block_size = initial_block_size

    mock_model_config = MagicMock()
    mock_model_config.dtype = "float16"

    # Minimal VllmConfig stand-in with just the attributes
    # check_and_update_config reads on the TPU code path.
    mock_vllm_config = MagicMock()
    mock_vllm_config.cache_config = mock_cached_config
    mock_vllm_config.compilation_config = MagicMock()
    # TPU only supports the DYNAMO_ONCE compilation level with the
    # openxla backend, so pre-set both to avoid the platform rewriting them.
    mock_vllm_config.compilation_config.level = (
        vllm.config.CompilationLevel.DYNAMO_ONCE)
    mock_vllm_config.compilation_config.backend = "openxla"
    mock_vllm_config.model_config = mock_model_config
    mock_vllm_config.speculative_config = None
    mock_vllm_config.parallel_config = MagicMock()
    mock_vllm_config.parallel_config.worker_cls = (
        "vllm.v1.worker.tpu_worker.TPUWorker")
    mock_vllm_config.scheduler_config = MagicMock()

    # act: force the v0/v1 engine selection for this case.
    with patch("vllm.envs.VLLM_USE_V1", use_v1):
        TpuPlatform.check_and_update_config(mock_vllm_config)

    # assert: block size either preserved or filled with the default.
    assert mock_cached_config.block_size == expected_block_size
    if use_v1:
        # v1 always validates against the minimum page size ...
        mock_get_min_page_size.assert_called()
        # ... but computes a page size only when the user set none.
        if initial_block_size is None:
            mock_get_page_size.assert_called()
        else:
            mock_get_page_size.assert_not_called()
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,9 +93,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: | |
from vllm.config import CompilationLevel | ||
|
||
cache_config = vllm_config.cache_config | ||
# For v0, the default block size is 16. | ||
if cache_config and cache_config.block_size is None: | ||
cache_config.block_size = cast(BlockSize, 16) | ||
assert cache_config is not None | ||
|
||
compilation_config = vllm_config.compilation_config | ||
|
||
# TPU only supports DYNAMO_ONCE compilation level | ||
|
@@ -118,8 +117,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: | |
if envs.VLLM_USE_V1: | ||
from vllm.v1.attention.backends.pallas import ( | ||
PallasAttentionBackend) | ||
cache_config.block_size = PallasAttentionBackend.get_page_size( | ||
vllm_config) # type: ignore[assignment] | ||
|
||
# For v1, the default block size is calculated from vllm_config. | ||
cache_config.block_size = ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we add warning when the block_size set by user is not optimal? cc @bythew3i There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we know what value is optimal for the user, I think we should add. The reason I want to put back the In my benchmark testing, my |
||
cache_config.block_size | ||
or PallasAttentionBackend.get_page_size( | ||
vllm_config) # type: ignore[assignment] | ||
) | ||
|
||
min_page_size = PallasAttentionBackend.get_min_page_size( | ||
vllm_config) | ||
if min_page_size > cache_config.block_size: | ||
|
@@ -130,7 +135,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: | |
min_page_size, | ||
) | ||
cache_config.block_size = min_page_size # type: ignore[assignment] | ||
|
||
else: | ||
# For v0, the default block size is 16. | ||
cache_config.block_size = (cache_config.block_size | ||
or cast(BlockSize, 16)) | ||
parallel_config = vllm_config.parallel_config | ||
scheduler_config = vllm_config.scheduler_config | ||
if parallel_config.worker_cls == "auto": | ||
|
Uh oh!
There was an error while loading. Please reload this page.