Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def __post_init__(self):
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")

if current_platform.is_cuda_alike() or current_platform.is_xpu():
if current_platform.support_static_graph_mode():
# if cudagraph_mode is not explicitly set by users, set default
# value
if self.compilation_config.cudagraph_mode is None:
Expand Down
4 changes: 4 additions & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,10 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
def support_hybrid_kv_cache(cls) -> bool:
return True

@classmethod
def support_static_graph_mode(cls) -> bool:
return True


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
Expand Down
7 changes: 7 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,13 @@ def support_hybrid_kv_cache(cls) -> bool:
"""
return False

@classmethod
def support_static_graph_mode(cls) -> bool:
"""
Returns if the graph mode is supported by the current platform.
"""
return False

@classmethod
def use_sync_weight_loader(cls) -> bool:
"""
Expand Down
4 changes: 4 additions & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,3 +477,7 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
return True

@classmethod
def support_static_graph_mode(cls) -> bool:
return True
13 changes: 7 additions & 6 deletions vllm/platforms/xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# lazy import to avoid circular import
from vllm.config import CompilationLevel, CUDAGraphMode
compilation_config = vllm_config.compilation_config
if compilation_config.cudagraph_mode is None or \
compilation_config.cudagraph_mode.max_cudagraph_mode() \
!= CUDAGraphMode.NONE:
logger.info("[XPU] CUDA graph is not supported on XPU, disabling "
"cudagraphs. Fallback to cudagraph_mode=NONE")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE

assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, \
"CUDA graph mode should be NONE on XPU"

if vllm_config.lora_config is not None:
compilation_config.level = CompilationLevel.NO_COMPILATION
Expand Down Expand Up @@ -169,6 +166,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
def support_hybrid_kv_cache(cls) -> bool:
return True

@classmethod
def support_static_graph_mode(cls) -> bool:
return False

@classmethod
def is_pin_memory_available(cls):
return True
Expand Down