vllm/engine/arg_utils.py (2 additions, 3 deletions)
```diff
@@ -1669,9 +1669,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             _raise_or_fallback(feature_name=name, recommend_to_remove=True)
             return False
 
-        # No support for device type other than CUDA, AMD (experiemntal) or
-        # TPU (experimental) so far.
-        if not (current_platform.is_cuda_alike() or current_platform.is_tpu()):
+        # Platforms must decide if they can support v1 for this model
+        if not current_platform.supports_v1(model_config=model_config):
            _raise_or_fallback(
                feature_name=f"device type={current_platform.device_type}",
                recommend_to_remove=False)
```
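The call-site change above swaps a hard-coded device allowlist for a query on the platform object. A minimal standalone sketch of that dispatch pattern, with stub classes standing in for vllm's real `Platform`, `ModelConfig`, and `current_platform` (none of this is vllm's actual code):

```python
# Standalone sketch: Platform, CudaPlatform, ModelConfig, and
# current_platform are stand-ins mirroring the diff, not vllm imports.

class ModelConfig:
    """Stub for vllm.config.ModelConfig."""


class Platform:
    device_type = "unspecified"

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        return False  # conservative default: platforms must opt in


class CudaPlatform(Platform):
    device_type = "cuda"

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        return True


current_platform = CudaPlatform()


def is_v1_supported(model_config: ModelConfig) -> bool:
    # The oracle no longer enumerates device types; it asks the platform.
    if not current_platform.supports_v1(model_config=model_config):
        print(f"fallback: device type={current_platform.device_type}")
        return False
    return True


assert is_v1_supported(ModelConfig())  # the stub "cuda" platform opts in
```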
vllm/platforms/cuda.py (6 additions, 1 deletion)
```diff
@@ -20,8 +20,9 @@
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 else:
+    ModelConfig = None
     VllmConfig = None
 
 logger = init_logger(__name__)
@@ -303,6 +304,10 @@ def get_device_communicator_cls(cls) -> str:
     def supports_fp8(cls) -> bool:
         return cls.has_device_capability(89)
 
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        return True
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
```
vllm/platforms/interface.py (9 additions, 1 deletion)
```diff
@@ -12,9 +12,10 @@
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
     from vllm.utils import FlexibleArgumentParser
 else:
+    ModelConfig = None
     VllmConfig = None
     FlexibleArgumentParser = None
 
@@ -371,6 +372,13 @@ def use_all_gather(cls) -> bool:
                 or parallel_config.distributed_executor_backend
                 == "external_launcher")
 
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        """Returns whether the current platform can support v1 for the supplied
+        model configuration.
+        """
+        return False
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
```
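Because the base hook defaults to `False`, any platform that does not override it falls back automatically, and an out-of-tree platform opts in by overriding the method, optionally inspecting the model configuration first. A hypothetical sketch, assuming a vllm checkout where `Platform` and `PlatformEnum` live in `vllm.platforms.interface` as in this diff (`MyAcceleratorPlatform` and the encoder-decoder gate are illustrative, not from the PR):

```python
# Hypothetical out-of-tree platform; not part of this PR.
from typing import TYPE_CHECKING

from vllm.platforms.interface import Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.config import ModelConfig


class MyAcceleratorPlatform(Platform):
    _enum = PlatformEnum.OOT  # out-of-tree slot, if this enum member exists
    device_type = "my_accel"

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        # Illustrative gate: opt in only for decoder-only models,
        # fall back to v0 for everything else.
        return not getattr(model_config, "is_encoder_decoder", False)
```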
vllm/platforms/rocm.py (7 additions, 1 deletion)
```diff
@@ -12,8 +12,9 @@
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 else:
+    ModelConfig = None
     VllmConfig = None
 
 logger = init_logger(__name__)
@@ -249,3 +250,8 @@ def fp8_dtype(cls) -> torch.dtype:
             return torch.float8_e4m3fnuz
         else:
             return torch.float8_e4m3fn
+
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        # V1 support on AMD GPUs is experimental
+        return True
```
vllm/platforms/tpu.py (7 additions, 1 deletion)
```diff
@@ -10,8 +10,9 @@
 from .interface import Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 else:
+    ModelConfig = None
     VllmConfig = None
 
 logger = init_logger(__name__)
@@ -127,3 +128,8 @@ def get_device_communicator_cls(cls) -> str:
     @classmethod
     def use_all_gather(cls) -> bool:
         return True
+
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        # V1 support on TPU is experimental
+        return True
```
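Note the design choice: every override added here accepts `model_config` and ignores it. CUDA opts in unconditionally, and the ROCm and TPU paths opt in while flagged as experimental in their comments. Threading the model configuration through the signature anyway leaves room for a platform to gate v1 on per-model properties later without touching the oracle in `arg_utils.py` again.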