8 changes: 7 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -12,7 +12,13 @@ steps:
   command: pytest -v -s async_engine
 
 - label: Basic Correctness Test
-  command: pytest -v -s basic_correctness
+  commands:
+  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
+  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
+  - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
+  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
+  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
+  - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py
 
 - label: Core Test
   command: pytest -v -s core
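The pipeline change above runs the basic-correctness tests once per attention backend by setting VLLM_ATTENTION_BACKEND before invoking pytest. A minimal sketch of reproducing that matrix locally, assuming it is launched from the repository's tests/ directory (backend names, env-var name, and test paths are taken from the commands above; which backends actually work depends on your hardware):

```python
# Sweep the attention backends the CI matrix above exercises.
import os
import subprocess

BACKENDS = ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"]
TEST_FILES = [
    "basic_correctness/test_basic_correctness.py",
    "basic_correctness/test_chunked_prefill.py",
]

for backend in BACKENDS:
    for test_file in TEST_FILES:
        env = {**os.environ, "VLLM_ATTENTION_BACKEND": backend}
        # check=False: a backend that is unsupported on this machine should
        # not abort the rest of the sweep.
        subprocess.run(["pytest", "-v", "-s", test_file], env=env, check=False)
```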
6 changes: 0 additions & 6 deletions tests/basic_correctness/test_basic_correctness.py
@@ -4,8 +4,6 @@
 """
 import pytest
 
-from vllm.attention.selector import VLLM_ATTENTION_BACKEND
-
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
@@ -16,7 +14,6 @@
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False, True])
-@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -25,10 +22,7 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
-    attn_backend: str,
-    monkeypatch,
 ) -> None:
-    monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend)
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
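This diff drops the per-test backend parametrization; the backend is now chosen once per CI step through the environment. For reference, a self-contained sketch of the removed pattern (backend values and env-var name as in the pipeline above):

```python
# Sketch of the per-test pattern being removed: pin an attention backend for
# a single test via the VLLM_ATTENTION_BACKEND environment variable.
import os

import pytest


@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
def test_backend_env_is_set(attn_backend: str, monkeypatch) -> None:
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
    # Anything constructed after this point (e.g. a vllm_runner) would see
    # the selected backend.
    assert os.environ["VLLM_ATTENTION_BACKEND"] == attn_backend
```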
4 changes: 0 additions & 4 deletions tests/basic_correctness/test_chunked_prefill.py
@@ -33,10 +33,6 @@ def test_models(
     enforce_eager: bool,
     tensor_parallel_size: int,
 ) -> None:
-    if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16
-            and not enforce_eager):
-        pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
-                    "for high TP to save testing time.")
     max_num_seqs = min(chunked_prefill_token_size, 256)
     enable_chunked_prefill = False
     max_num_batched_tokens = None
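The deleted lines skipped some expensive combinations of the parametrized arguments. A standalone sketch of that skip pattern (the parametrize values here are illustrative assumptions, not the test's actual ones):

```python
# Illustration of the guard removed above: prune expensive parameter
# combinations to keep CI time down. Parametrize values are hypothetical.
import pytest


@pytest.mark.parametrize("chunked_prefill_token_size", [16, 256])
@pytest.mark.parametrize("enforce_eager", [False, True])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_combination(chunked_prefill_token_size: int, enforce_eager: bool,
                     tensor_parallel_size: int) -> None:
    if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16
            and not enforce_eager):
        pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
                    "for high TP to save testing time.")
    # ... the real test would build the model and compare outputs here ...
```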
18 changes: 6 additions & 12 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -162,7 +162,7 @@ def __init__(
             # AMD Radeon 7900 series (gfx1100) currently does not support
             # xFormers nor FlashAttention. As a temporary workaround, we use
             # naive PyTorch implementation of attention.
-            self.attn_fuc = _naive_attention()
+            self.attn_fuc = _naive_attention
             logger.debug("Using naive attention in ROCmBackend")
         elif self.use_triton_flash_attn:
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
@@ -334,26 +334,21 @@ def _naive_attention(
     prompt_lens: List[int],
     scale: float,
 ) -> torch.Tensor:
-    num_tokens = query.shape[0]
     output = torch.empty_like(query)
     start = 0
     for _, prompt_len in enumerate(prompt_lens):
         end = start + prompt_len
         out = _naive_masked_attention(
-            query[None, start:end],
-            key[None, start:end],
-            value[None, start:end],
+            query[start:end],
+            key[start:end],
+            value[start:end],
             scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
         output[start:end].copy_(out)
         start += prompt_len
 
-    # Using view got RuntimeError: view size is not compatible
-    # with input tensor's size and stride (at least one
-    # dimension spans across two contiguous subspaces).
-    # Use reshape instead.
-    return output.reshape(num_tokens, -1)
+    return output
 
 
 def _naive_masked_attention(
@@ -362,14 +357,13 @@ def _naive_masked_attention(
     value: torch.Tensor,
     scale: float,
 ) -> torch.Tensor:
-    seq_len, _, _ = query.shape
+    seq_len, head_size, head_dim = query.shape
     attn_mask = torch.triu(torch.ones(seq_len,
                                       seq_len,
                                       dtype=query.dtype,
                                       device=query.device),
                            diagonal=1)
     attn_mask = attn_mask * torch.finfo(query.dtype).min
-
     attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
     attn_weights = attn_weights + attn_mask.float()
     attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
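With this change, the naive path keeps 3-D [num_tokens, num_heads, head_dim] tensors end to end: _naive_attention passes per-prompt slices directly into _naive_masked_attention and no longer reshapes the result. A self-contained sketch of the masked-attention math from this hunk; the final output einsum and the cross-check against PyTorch's reference implementation are my additions, not lines from the diff:

```python
# Standalone sketch (not vLLM code) of naive causal attention over 3-D
# [seq_len, num_heads, head_dim] tensors, following the hunk above.
import torch


def naive_masked_attention(query: torch.Tensor, key: torch.Tensor,
                           value: torch.Tensor, scale: float) -> torch.Tensor:
    seq_len, _, _ = query.shape
    # Upper-triangular mask of large negative values blocks attention to
    # future positions.
    attn_mask = torch.triu(
        torch.ones(seq_len, seq_len, dtype=query.dtype, device=query.device),
        diagonal=1) * torch.finfo(query.dtype).min
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    # Weighted sum over keys, back to [seq_len, num_heads, head_dim].
    return torch.einsum("hqk,khd->qhd", attn_weights, value)


seq_len, num_heads, head_dim = 8, 2, 16
q, k, v = (torch.randn(seq_len, num_heads, head_dim) for _ in range(3))
out = naive_masked_attention(q, k, v, scale=head_dim**-0.5)

# Cross-check against PyTorch's reference causal attention (heads-first).
ref = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1),
    is_causal=True).transpose(0, 1)
assert torch.allclose(out, ref, atol=1e-5)
```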