Skip to content

Commit 149389c

Browse files
joerunde authored and yangw-dev committed
[Hardware] add platform-specific request validation api (vllm-project#16291)
Signed-off-by: Joe Runde <[email protected]>
Signed-off-by: Yang Wang <[email protected]>
1 parent 479971b commit 149389c

File tree

9 files changed

+38
-41
lines changed

9 files changed

+38
-41
lines changed

vllm/platforms/cpu.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,3 @@ def get_device_communicator_cls(cls) -> str:
180180
Get device specific communicator class for distributed communication.
181181
"""
182182
return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator" # noqa
183-
184-
@classmethod
185-
def supports_structured_output(cls) -> bool:
186-
return True

vllm/platforms/cuda.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,6 @@ def supports_fp8(cls) -> bool:
308308
def supports_v1(cls, model_config: ModelConfig) -> bool:
309309
return True
310310

311-
@classmethod
312-
def supports_structured_output(cls) -> bool:
313-
return True
314-
315311
@classmethod
316312
def use_custom_allreduce(cls) -> bool:
317313
return True

vllm/platforms/hpu.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,3 @@ def get_punica_wrapper(cls) -> str:
9292
@classmethod
9393
def get_device_communicator_cls(cls) -> str:
9494
return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator" # noqa
95-
96-
@classmethod
97-
def supports_structured_output(cls) -> bool:
98-
return True

vllm/platforms/interface.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# SPDX-License-Identifier: Apache-2.0
2-
32
import enum
43
import platform
54
import random
@@ -9,14 +8,21 @@
98
import numpy as np
109
import torch
1110

11+
from vllm.inputs import PromptType
1212
from vllm.logger import init_logger
1313

1414
if TYPE_CHECKING:
1515
from vllm.config import ModelConfig, VllmConfig
16+
from vllm.lora.request import LoRARequest
17+
from vllm.pooling_params import PoolingParams
18+
from vllm.sampling_params import SamplingParams
1619
from vllm.utils import FlexibleArgumentParser
1720
else:
1821
ModelConfig = None
1922
VllmConfig = None
23+
LoRARequest = None
24+
PoolingParams = None
25+
SamplingParams = None
2026
FlexibleArgumentParser = None
2127

2228
logger = init_logger(__name__)
@@ -379,20 +385,21 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
379385
"""
380386
return False
381387

382-
@classmethod
383-
def supports_structured_output(cls) -> bool:
384-
"""
385-
Returns whether the current platform can support structured output.
386-
"""
387-
return False
388-
389388
@classmethod
390389
def use_custom_allreduce(cls) -> bool:
391390
"""
392391
Returns if custom allreduce is supported on the current platform
393392
"""
394393
return False
395394

395+
@classmethod
396+
def validate_request(
397+
cls,
398+
prompt: PromptType,
399+
params: Union[SamplingParams, PoolingParams],
400+
) -> None:
401+
"""Raises if this request is unsupported on this platform"""
402+
396403

397404
class UnspecifiedPlatform(Platform):
398405
_enum = PlatformEnum.UNSPECIFIED

vllm/platforms/neuron.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,3 @@ def get_device_communicator_cls(cls) -> str:
6767
@classmethod
6868
def use_all_gather(cls) -> bool:
6969
return True
70-
71-
@classmethod
72-
def supports_structured_output(cls) -> bool:
73-
return True

vllm/platforms/rocm.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,6 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
303303
# V1 support on AMD gpus is experimental
304304
return True
305305

306-
@classmethod
307-
def supports_structured_output(cls) -> bool:
308-
return True
309-
310306
@classmethod
311307
def use_custom_allreduce(cls) -> bool:
312308
# We only enable custom allreduce for MI300 series

vllm/platforms/tpu.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3-
from typing import TYPE_CHECKING, Optional
3+
from typing import TYPE_CHECKING, Optional, Union
44

55
import torch
66

77
import vllm.envs as envs
8+
from vllm.inputs import PromptType
89
from vllm.logger import init_logger
910

1011
from .interface import Platform, PlatformEnum, _Backend
1112

1213
if TYPE_CHECKING:
1314
from vllm.config import ModelConfig, VllmConfig
15+
from vllm.lora.request import LoRARequest
16+
from vllm.pooling_params import PoolingParams
17+
from vllm.sampling_params import SamplingParams
1418
else:
1519
ModelConfig = None
1620
VllmConfig = None
21+
LoRARequest = None
22+
PoolingParams = None
23+
SamplingParams = None
1724

1825
logger = init_logger(__name__)
1926

@@ -135,6 +142,13 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
135142
return True
136143

137144
@classmethod
138-
def supports_structured_output(cls) -> bool:
139-
# Structured output is not supported on TPU.
140-
return False
145+
def validate_request(
146+
cls,
147+
prompt: PromptType,
148+
params: Union[SamplingParams, PoolingParams],
149+
) -> None:
150+
"""Raises if this request is unsupported on this platform"""
151+
if isinstance(params,
152+
SamplingParams) and params.guided_decoding is not None:
153+
raise ValueError("Structured output is not supported on "
154+
f"{cls.device_name}.")

vllm/platforms/xpu.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,3 @@ def device_support_bf16(cls) -> bool:
140140
@classmethod
141141
def get_device_communicator_cls(cls) -> str:
142142
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
143-
144-
@classmethod
145-
def supports_structured_output(cls) -> bool:
146-
return True

vllm/v1/engine/processor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,6 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
141141
else:
142142
params.guided_decoding.backend = engine_level_backend
143143

144-
from vllm.platforms import current_platform
145-
if not current_platform.supports_structured_output():
146-
raise ValueError("Structured output is not supported on "
147-
f"{current_platform.device_name}.")
148-
149144
# Request content validation
150145
if engine_level_backend.startswith("xgrammar"):
151146
# xgrammar with no fallback
@@ -187,6 +182,11 @@ def process_inputs(
187182
# TODO(woosuk): Support pooling models.
188183
# TODO(woosuk): Support encoder-decoder models.
189184

185+
from vllm.platforms import current_platform
186+
current_platform.validate_request(
187+
prompt=prompt,
188+
params=params,
189+
)
190190
self._validate_lora(lora_request)
191191
self._validate_params(params)
192192
if priority != 0:

0 commit comments

Comments (0)