Empty file added tests/attention/__init__.py
Empty file.
155 changes: 155 additions & 0 deletions tests/attention/test_attention_selector.py
@@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random

import pytest
import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import choose_attention_backend

NUM_BACKENDS = [1, 2, 3, 4]
HEAD_SIZES = [256]
ATTENTION_DTYPES = [torch.bfloat16]
KVCACHE_DTYPES = ["auto"]
BLOCK_SIZES = [1]


class MockUnsupportedAttentionBackend(AttentionBackend):

def __init__(self):
pass

@classmethod
def get_name(cls) -> str:
return "MOCK_UNSUPPORTED_BACKEND"

@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return []

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return []


class MockSupportedAttentionBackend(AttentionBackend):

def __init__(self):
pass

@classmethod
def get_name(cls) -> str:
return "MOCK_SUPPORTED_BACKEND"

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return HEAD_SIZES

@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return ATTENTION_DTYPES


def get_full_qualname(cls):
"""
Returns the fully qualified class path, e.g. 'package.module.ClassName'
"""
return f"{cls.__module__}.{cls.__qualname__}"


def generate_unsupported_backend_mapping(num_backends: int) -> dict[str, str]:
backend_qualname = get_full_qualname(MockUnsupportedAttentionBackend)
return {
f"{MockUnsupportedAttentionBackend.get_name()}_{index}":
backend_qualname
for index in range(num_backends)
}


def generate_supported_backend_mapping(num_backends: int) -> dict[str, str]:
backend_qualname = get_full_qualname(MockSupportedAttentionBackend)
return {
f"{MockSupportedAttentionBackend.get_name()}_{index}": backend_qualname
for index in range(num_backends)
}


@pytest.mark.parametrize("num_backends", NUM_BACKENDS)
@pytest.mark.parametrize("arbitrary_head_size", HEAD_SIZES)
@pytest.mark.parametrize("arbitrary_dtype", ATTENTION_DTYPES)
@pytest.mark.parametrize("arbitrary_kvcache_dtype", KVCACHE_DTYPES)
@pytest.mark.parametrize("arbitrary_block_size", BLOCK_SIZES)
def test_choose_attention_backend_raises_on_no_supported_backend(
num_backends: int, arbitrary_head_size: int,
arbitrary_dtype: torch.dtype, arbitrary_kvcache_dtype: str,
arbitrary_block_size: int) -> None:

unsupported_backends = generate_unsupported_backend_mapping(num_backends)

with pytest.raises(ValueError):
choose_attention_backend(unsupported_backends, arbitrary_head_size,
arbitrary_dtype, arbitrary_kvcache_dtype,
arbitrary_block_size)


@pytest.mark.parametrize("num_backends", NUM_BACKENDS)
@pytest.mark.parametrize("arbitrary_head_size", HEAD_SIZES)
@pytest.mark.parametrize("arbitrary_dtype", ATTENTION_DTYPES)
@pytest.mark.parametrize("arbitrary_kvcache_dtype", KVCACHE_DTYPES)
@pytest.mark.parametrize("arbitrary_block_size", BLOCK_SIZES)
def test_choose_attention_backend_returns_qualname_for_supported_backend_only(
num_backends: int, arbitrary_head_size: int,
arbitrary_dtype: torch.dtype, arbitrary_kvcache_dtype: str,
arbitrary_block_size: int) -> None:
supported_backend_qual_name = get_full_qualname(
MockSupportedAttentionBackend)
supported_backends = generate_supported_backend_mapping(num_backends)
_, chosen_backend_qualname = choose_attention_backend(
supported_backends, arbitrary_head_size, arbitrary_dtype,
arbitrary_kvcache_dtype, arbitrary_block_size)
assert chosen_backend_qualname == supported_backend_qual_name


@pytest.mark.parametrize("num_backends", NUM_BACKENDS)
@pytest.mark.parametrize("arbitrary_head_size", HEAD_SIZES)
@pytest.mark.parametrize("arbitrary_dtype", ATTENTION_DTYPES)
@pytest.mark.parametrize("arbitrary_kvcache_dtype", KVCACHE_DTYPES)
@pytest.mark.parametrize("arbitrary_block_size", BLOCK_SIZES)
def test_choose_attention_backend_returns_supported_backend_qualname(
num_backends: int, arbitrary_head_size: int,
arbitrary_dtype: torch.dtype, arbitrary_kvcache_dtype: str,
arbitrary_block_size: int) -> None:
supported_backend_qual_name = get_full_qualname(
MockSupportedAttentionBackend)
unsupported_backends = generate_unsupported_backend_mapping(num_backends)
supported_backends = generate_supported_backend_mapping(num_backends)
all_backends = supported_backends | unsupported_backends
_, chosen_backend_qualname = choose_attention_backend(
all_backends, arbitrary_head_size, arbitrary_dtype,
arbitrary_kvcache_dtype, arbitrary_block_size)

assert chosen_backend_qualname == supported_backend_qual_name


@pytest.mark.parametrize("num_backends", NUM_BACKENDS)
@pytest.mark.parametrize("arbitrary_head_size", HEAD_SIZES)
@pytest.mark.parametrize("arbitrary_dtype", ATTENTION_DTYPES)
@pytest.mark.parametrize("arbitrary_kvcache_dtype", KVCACHE_DTYPES)
@pytest.mark.parametrize("arbitrary_block_size", BLOCK_SIZES)
def test_choose_attention_backend_forces_backend_via_env(
num_backends: int, arbitrary_head_size: int,
arbitrary_dtype: torch.dtype, arbitrary_kvcache_dtype: str,
arbitrary_block_size: int, monkeypatch) -> None:
supported_backends = generate_supported_backend_mapping(num_backends)

# Force an arbitrary supported backend
arbitrary_backend_name, _ = random.choice(list(supported_backends.items()))
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", arbitrary_backend_name)

chosen_backend_name, _ = choose_attention_backend(supported_backends,
arbitrary_head_size,
arbitrary_dtype,
arbitrary_kvcache_dtype,
arbitrary_block_size)
assert chosen_backend_name == arbitrary_backend_name
3 changes: 2 additions & 1 deletion tests/kernels/attention/test_attention_selector.py
@@ -126,7 +126,8 @@ def test_env(
block_size,
False,
use_mla=use_mla)
expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
expected = ("TRITON_UNIFIED_ATTENTION_V1"
if use_v1 else "ROCM_FLASH")
assert backend.get_name() == expected

elif device == "cuda":
2 changes: 1 addition & 1 deletion tests/kernels/attention/test_rocm_attention_selector.py
@@ -27,7 +27,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# Test standard ROCm attention
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
assert (backend.get_name() == "ROCM_FLASH"
or backend.get_name() == "TRITON_ATTN_VLLM_V1")
or backend.get_name() == "TRITON_UNIFIED_ATTENTION_V1")

# MLA test for deepseek related

3 changes: 2 additions & 1 deletion tests/v1/attention/test_attention_backends.py
@@ -17,7 +17,8 @@

BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
_Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN
_Backend.TRITON_UNIFIED_ATTENTION_V1, _Backend.FLEX_ATTENTION,
_Backend.TREE_ATTN
]

# Remove flashinfer from the list if it's not available
4 changes: 2 additions & 2 deletions tests/v1/attention/utils.py
@@ -124,8 +124,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.flashinfer.FlashInferBackend",
_Backend.FLEX_ATTENTION:
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
_Backend.TRITON_ATTN_VLLM_V1:
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
_Backend.TRITON_UNIFIED_ATTENTION_V1:
"vllm.v1.attention.backends.triton_attn.TritonUnifiedAttentionBackend",
_Backend.TREE_ATTN:
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
}
150 changes: 141 additions & 9 deletions vllm/attention/selector.py
@@ -85,26 +85,45 @@ class _IsSupported:
can_import: bool
head_size: bool
dtype: bool
kv_cache_dtype: bool
block_size: bool
device_capabality: bool
reasons: list[str]

def __bool__(self) -> bool:
return self.can_import and self.head_size and self.dtype
return (self.can_import and self.head_size and self.dtype
and self.kv_cache_dtype and self.block_size
and self.device_capabality)

def get_humanized_reasons(self) -> str:
return "\n".join(f"{i+1}. {item}"
for i, item in enumerate(self.reasons))


def is_attn_backend_supported(
attn_backend: Union[str, type[AttentionBackend]],
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str = "",
block_size: int = 0,
*,
allow_import_error: bool = True,
) -> _IsSupported:
reasons = []
if isinstance(attn_backend, str):
try:
attn_backend = resolve_obj_by_qualname(attn_backend)
except ImportError:
if not allow_import_error:
raise

return _IsSupported(can_import=False, head_size=False, dtype=False)
reasons.append("Could not import attention backend")
return _IsSupported(can_import=False,
head_size=False,
dtype=False,
kv_cache_dtype=False,
block_size=False,
device_capabality=False,
reasons=reasons)

assert isinstance(attn_backend, type)

@@ -117,7 +136,8 @@ def is_attn_backend_supported(
try:
validate_head_size(head_size)
is_head_size_supported = True
except Exception:
except Exception as e:
reasons.append(str(e))
is_head_size_supported = False
else:
raise NotImplementedError(f"{attn_backend.__name__} does not support "
Expand All @@ -130,11 +150,41 @@ def is_attn_backend_supported(
raise NotImplementedError(f"{attn_backend.__name__} does not support "
"dtype validation")

return _IsSupported(
can_import=True,
head_size=is_head_size_supported,
dtype=is_dtype_supported,
)
is_kv_cache_dtype_supported = True
if validate_kv_cache_dtype := getattr(attn_backend,
"validate_kv_cache_dtype", None):
try:
validate_kv_cache_dtype(kv_cache_dtype)
except Exception as e:
reasons.append(str(e))
is_kv_cache_dtype_supported = False

is_device_capabality_supported = True
if validate_device_capabality := getattr(attn_backend,
"validate_device_capabality",
None):
try:
validate_device_capabality()
except Exception as e:
reasons.append(str(e))
is_device_capabality_supported = False

is_block_size_supported = True
if validate_block_size := getattr(attn_backend, "validate_block_size",
None):
try:
validate_block_size(block_size)
except Exception as e:
reasons.append(str(e))
is_block_size_supported = False

return _IsSupported(can_import=True,
head_size=is_head_size_supported,
dtype=is_dtype_supported,
kv_cache_dtype=is_kv_cache_dtype_supported,
block_size=is_block_size_supported,
device_capabality=is_device_capabality_supported,
reasons=reasons)


def get_attn_backend(
@@ -238,3 +288,85 @@ def global_force_attn_backend_context_manager(
finally:
# Revert the original global backend override, if any
global_force_attn_backend(original_value)


def choose_attention_backend(
backend_to_qualname: dict[str, str],
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str,
block_size: int,
) -> tuple[str, str]:
"""
Selects and returns a suitable attention backend for the given
configuration.

The function first attempts to select the backend named by the
environment variable `VLLM_ATTENTION_BACKEND`, if it is set. If the
forced backend is invalid or not supported for the given
configuration, it falls back to automatically selecting the first
supported backend from `backend_to_qualname`.

Parameters:
backend_to_qualname (dict[str, str]): Mapping from backend names to
their qualified names.
head_size (int): Size of the attention head.
dtype (torch.dtype): Data type for computation.
kv_cache_dtype (str): Data type of the key-value cache ("auto" or
specific type).
block_size (int): Block size to use for the backend.

Returns:
tuple[str, str]: (backend name, backend qualified name) of the
selected backend.

Raises:
ValueError: If no supported backend is found for the given
configuration.
"""
maybe_forced_backend = envs.VLLM_ATTENTION_BACKEND
if maybe_forced_backend:
if maybe_forced_backend not in backend_to_qualname:
message = f"VLLM_ATTENTION_BACKEND is set, but " \
f"{maybe_forced_backend} is not a valid " \
"attention backend."

logger.warning(message)
else:
qualified_name = backend_to_qualname[maybe_forced_backend]
if is_supported := is_attn_backend_supported(
qualified_name,
head_size,
dtype,
kv_cache_dtype,
block_size,
allow_import_error=False):
message = f"{maybe_forced_backend} has been forced. " \
f"Unset VLLM_ATTENTION_BACKEND to enable " \
"auto-selection."

logger.warning(message)
return maybe_forced_backend, qualified_name

else:
failure_reasons = is_supported.get_humanized_reasons()
message = f"Tried to force {maybe_forced_backend}, " \
"but it is not supported with the given " \
"configuration for the following reasons: " \
f"\n{failure_reasons}."

logger.warning(message)

logger.info("Reverting back to auto-selection for attention backend.")

for backend_name, qualname in backend_to_qualname.items():
if is_attn_backend_supported(qualname,
head_size,
dtype,
kv_cache_dtype,
block_size,
allow_import_error=False):
return backend_name, qualname

raise ValueError(
"No attention backend supports the current configuration.")