diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 38a47a846df7..69a164bbc6a8 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1669,9 +1669,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             _raise_or_fallback(feature_name=name, recommend_to_remove=True)
             return False
 
-        # No support for device type other than CUDA, AMD (experiemntal) or
-        # TPU (experimental) so far.
-        if not (current_platform.is_cuda_alike() or current_platform.is_tpu()):
+        # Platforms must decide if they can support v1 for this model
+        if not current_platform.supports_v1(model_config=model_config):
             _raise_or_fallback(
                 feature_name=f"device type={current_platform.device_type}",
                 recommend_to_remove=False)
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index bb77318092fc..ca8a2d2640ec 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -20,8 +20,9 @@
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 else:
+    ModelConfig = None
     VllmConfig = None
 
 logger = init_logger(__name__)
@@ -303,6 +304,10 @@ def get_device_communicator_cls(cls) -> str:
     def supports_fp8(cls) -> bool:
         return cls.has_device_capability(89)
 
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        return True
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 9981deee39b7..36db70681a19 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -12,9 +12,10 @@
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
     from vllm.utils import FlexibleArgumentParser
 else:
+    ModelConfig = None
     VllmConfig = None
     FlexibleArgumentParser = None
 
@@ -371,6 +372,13 @@ def use_all_gather(cls) -> bool:
                 or parallel_config.distributed_executor_backend
                 == "external_launcher")
 
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        """Returns whether the current platform can support v1 for the supplied
+        model configuration.
+        """
+        return False
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index ee708f5961df..d196e24ac7ac 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -12,8 +12,9 @@
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 else:
+    ModelConfig = None
     VllmConfig = None
 
 logger = init_logger(__name__)
@@ -249,3 +250,8 @@ def fp8_dtype(cls) -> torch.dtype:
             return torch.float8_e4m3fnuz
         else:
             return torch.float8_e4m3fn
+
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        # V1 support on AMD gpus is experimental
+        return True
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 073d46c25d57..43d3044cb93e 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -10,8 +10,9 @@
 from .interface import Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 else:
+    ModelConfig = None
     VllmConfig = None
 
 logger = init_logger(__name__)
@@ -127,3 +128,8 @@ def get_device_communicator_cls(cls) -> str:
     @classmethod
     def use_all_gather(cls) -> bool:
         return True
+
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        # V1 support on TPU is experimental
+        return True