Commit eee6378

[V1] TPU support
Signed-off-by: Alexander Matveev <[email protected]>
1 parent 24b0205 commit eee6378

16 files changed: +1759 -248 lines changed


.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion

@@ -89,4 +89,4 @@ repos:
     name: Suggestion
     entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
     language: system
-    verbose: true
+    verbose: true

examples/offline_inference/basic.py

Lines changed: 3 additions & 3 deletions

@@ -8,15 +8,15 @@
     "The future of AI is",
 ]
 # Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)

 # Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
 # Print the outputs.
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

tests/entrypoints/openai/test_accuracy.py

Lines changed: 10 additions & 3 deletions

@@ -66,14 +66,21 @@ def run_test(more_args):
     ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


-@pytest.mark.skipif(not current_platform.is_cuda(),
-                    reason="V1 currently only supported on CUDA")
+@pytest.mark.skipif(not current_platform.is_cuda()
+                    and not current_platform.is_tpu(),
+                    reason="V1 currently only supported on CUDA and TPU")
 def test_lm_eval_accuracy_v1_engine(monkeypatch):
     """Run with the V1 Engine."""

     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-        run_test([])
+        more_args = []
+
+        # Limit compilation time for V1
+        if current_platform.is_tpu():
+            more_args = ["--max-num-seqs", "64"]
+
+        run_test(more_args)


 @pytest.mark.parametrize("more_args", MORE_ARGS_LIST)

tools/mypy.sh

Lines changed: 1 addition & 1 deletion

@@ -34,4 +34,4 @@ run_mypy vllm/plugins
 run_mypy vllm/prompt_adapter
 run_mypy vllm/spec_decode
 run_mypy vllm/worker
-run_mypy vllm/v1
+run_mypy vllm/v1

vllm/platforms/cuda.py

Lines changed: 1 addition & 1 deletion

@@ -135,7 +135,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             else:
                 if envs.VLLM_USE_V1:
                     parallel_config.worker_cls = \
-                        "vllm.v1.worker.gpu_worker.Worker"
+                        "vllm.v1.worker.gpu_worker.GPUWorker"
                 else:
                     parallel_config.worker_cls = "vllm.worker.worker.Worker"

vllm/platforms/interface.py

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ class _Backend(enum.Enum):
     FLASHINFER = enum.auto()
     HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
+    PALLAS_VLLM_V1 = enum.auto()
     IPEX = enum.auto()
     BLOCK_SPARSE_FLASH_ATTN = enum.auto()
     NO_ATTENTION = enum.auto()

vllm/platforms/tpu.py

Lines changed: 44 additions & 15 deletions

@@ -2,6 +2,7 @@

 import torch

+import vllm.envs as envs
 from vllm.logger import init_logger

 from .interface import Platform, PlatformEnum, _Backend

@@ -30,10 +31,16 @@ class TpuPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool) -> str:
-        if selected_backend != _Backend.PALLAS:
+        if (selected_backend != _Backend.PALLAS
+                and selected_backend != _Backend.PALLAS_VLLM_V1):
             logger.info("Cannot use %s backend on TPU.", selected_backend)
-        logger.info("Using Pallas backend.")
-        return "vllm.attention.backends.pallas.PallasAttentionBackend"
+
+        if use_v1:
+            logger.info("Using Pallas V1 backend.")
+            return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
+        else:
+            logger.info("Using Pallas backend.")
+            return "vllm.attention.backends.pallas.PallasAttentionBackend"

     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:

@@ -45,7 +52,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:

     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return True
+        return not envs.VLLM_USE_V1

     @classmethod
     def inference_mode(cls):

@@ -60,22 +67,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             cache_config.block_size = 16

         compilation_config = vllm_config.compilation_config
-        if compilation_config.level == CompilationLevel.NO_COMPILATION:
-            # TPU does not support NO_COMPILATION
+
+        # TPU only supports DYNAMO_ONCE compilation level
+        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
+            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
             compilation_config.level = CompilationLevel.DYNAMO_ONCE
-        assert compilation_config.level < CompilationLevel.PIECEWISE,\
-            "TPU does not support Inductor."

         if compilation_config.backend == "":
             compilation_config.backend = "openxla"

         assert vllm_config.speculative_config is None, \
             "TPU does not support speculative decoding"

-        assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
-            "Chunked prefill is not yet supported for TPU backend")
-        assert not vllm_config.speculative_config, (
-            "Speculative decoding is not yet supported for TPU backend")
         if vllm_config.model_config.dtype in (torch.float16, torch.float32):
             logger.warning(
                 "The TPU backend currently does not support %s. "

@@ -85,8 +88,34 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
-            if scheduler_config.is_multi_step:
+            if envs.VLLM_USE_V1:
                 parallel_config.worker_cls = \
-                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                    "vllm.v1.worker.tpu_worker.TPUWorker"
             else:
-                parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
+                if scheduler_config.is_multi_step:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.tpu_worker.TPUWorker"
+
+        # Adjust scheduler config for V1
+        # TODO: Add support for these
+        if envs.VLLM_USE_V1:
+            if vllm_config.cache_config.enable_prefix_caching:
+                logger.warning("[V1][TPU] Disable prefix caching")
+                vllm_config.cache_config.enable_prefix_caching = False
+
+            if vllm_config.scheduler_config.chunked_prefill_enabled:
+                logger.warning("[V1][TPU] Disable chunked prefill")
+                vllm_config.scheduler_config.chunked_prefill_enabled = False
+
+        assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
+            "Chunked prefill is not yet supported for TPU backend")
+        assert not vllm_config.speculative_config, (
+            "Speculative decoding is not yet supported for TPU backend")
+
+    @classmethod
+    def is_pin_memory_available(cls):
+        logger.warning("Pin memory is not supported on TPU.")
+        return False
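
Taken together, the backend and worker selection added in this file reduces to a small decision tree. The sketch below is illustrative only (plain Python, no vLLM imports); it merely restates the dispatch encoded in get_attn_backend_cls and check_and_update_config, using the class paths that appear in the diff.

def select_tpu_classes(use_v1: bool, is_multi_step: bool) -> tuple[str, str]:
    """Illustrative summary of the TPU backend/worker dispatch in this commit."""
    if use_v1:
        # V1 engine: Pallas attention backend and worker from the vllm.v1 tree.
        backend = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
        worker = "vllm.v1.worker.tpu_worker.TPUWorker"
    else:
        # V0 engine: original Pallas backend; multi-step scheduling keeps its own worker.
        backend = "vllm.attention.backends.pallas.PallasAttentionBackend"
        worker = ("vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
                  if is_multi_step else "vllm.worker.tpu_worker.TPUWorker")
    return backend, worker


# V1 additionally forces DYNAMO_ONCE compilation and disables prefix caching and
# chunked prefill, as the check_and_update_config hunk above shows.
assert select_tpu_classes(use_v1=True, is_multi_step=False)[1] == "vllm.v1.worker.tpu_worker.TPUWorker"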
