Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion tests/v1/e2e/test_cascade_attention.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
# SPDX-License-Identifier: Apache-2.0

import pytest

from vllm import LLM, SamplingParams

from ...utils import fork_new_process_for_each_test


def test_cascade_attention(example_system_message, monkeypatch):
@fork_new_process_for_each_test
@pytest.mark.parametrize("attn_backend",
["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"])
def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
Expand Down
13 changes: 10 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,10 +1520,17 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=False)
return False

# No FlashInfer or XFormers so far.
# No XFormers so far.
V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
"TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
"FLASH_ATTN_VLLM_V1",
"FLASH_ATTN",
"PALLAS",
"PALLAS_VLLM_V1",
"TRITON_ATTN_VLLM_V1",
"TRITON_MLA",
"FLASHMLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
]
if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
Expand Down
3 changes: 3 additions & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
return ("vllm.attention.backends."
"flashmla.FlashMLABackend")
if use_v1:
if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend on V1 engine.")
return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends."
Expand Down
7 changes: 3 additions & 4 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ def get_kv_cache_shape(
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs)


@dataclass
class FlashAttentionMetadata:
Expand Down Expand Up @@ -402,6 +398,9 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
)
return attn_metadata

def use_cascade_attention(self, *args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs)


class FlashAttentionImpl(AttentionImpl):

Expand Down
Loading