tests/v1/determinism/test_batch_invariance.py (1 addition, 5 deletions)
@@ -185,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
- enable_prefix_caching=False,
+ # enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16", # not everything is supported
@@ -393,7 +393,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
gpu_memory_utilization=0.9,
max_model_len=2048,
dtype="bfloat16",
- enable_prefix_caching=False,
)

prompt = "the capital of france is"
@@ -457,7 +456,6 @@ def test_logprobs_without_batch_invariance_should_fail(
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
- enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
@@ -681,7 +679,6 @@ def test_decode_logprobs_match_prefill_logprobs(
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
- enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
@@ -928,7 +925,6 @@ def LLM_with_max_seqs(
max_model_len=max_model_len,
dtype="bfloat16",
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
- enable_prefix_caching=False,
# Enable for MOE models
# enable_expert_parallel=True,
)
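Note on the deletions above: the tests no longer pass enable_prefix_caching=False explicitly because, with the vllm/attention/layer.py change further down, the attention layer disables prefix caching on its own when batch invariance is enabled on a backend that does not support the combination. A minimal sketch of that usage pattern, with an illustrative model name and an assumed VLLM_BATCH_INVARIANT environment flag (neither is taken from the test file):

import os

from vllm import LLM, SamplingParams

# Assumption: batch invariance is switched on via this environment variable before the
# engine is constructed; vllm_is_batch_invariant() then reads the resulting setting.
os.environ["VLLM_BATCH_INVARIANT"] = "1"

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # illustrative model, not the one used by the tests
    tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
    max_num_seqs=32,
    max_model_len=8192,
    dtype="bfloat16",
    # enable_prefix_caching is left at its default: the attention layer now turns it off
    # by itself for backends that cannot combine it with batch invariance.
)

outputs = llm.generate(
    ["the capital of france is"],
    SamplingParams(temperature=0.0, max_tokens=8, logprobs=1),
)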
tests/v1/determinism/test_online_batch_invariance.py (4 additions, 1 deletion)
@@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
}

tp_size = os.getenv("VLLM_TP_SIZE", "1")
- server_args: list[str] = []
+ server_args: list[str] = [
+     "--max-model-len=8192",
+     "--max-num-seqs=32",
+ ]
if tp_size:
server_args += ["-tp", tp_size]

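For the online variant, the added --max-model-len and --max-num-seqs flags keep the server configuration in line with the offline tests above. A rough sketch of how such arguments reach an OpenAI-compatible vLLM server when launched by hand; the test itself uses a server fixture that is not part of this diff, and the model name below is illustrative:

import os
import subprocess

tp_size = os.getenv("VLLM_TP_SIZE", "1")
server_args: list[str] = [
    "--max-model-len=8192",  # match the offline determinism tests
    "--max-num-seqs=32",
]
if tp_size:
    server_args += ["-tp", tp_size]

# Hand-launched equivalent of what the test fixture does with these arguments.
subprocess.run(["vllm", "serve", "Qwen/Qwen2.5-1.5B-Instruct", *server_args], check=True)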
tests/v1/determinism/utils.py (1 addition, 0 deletions)
@@ -17,6 +17,7 @@

BACKENDS: list[str] = [
"FLASH_ATTN",
"TRITON_MLA",
]

if has_flashinfer():
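Because the determinism tests are parametrized over BACKENDS, adding "TRITON_MLA" here makes every backend-parametrized test also run against the Triton MLA backend. A hedged sketch of that pattern; the decorator and import path below are assumptions, and the real test bodies live in the files above:

import pytest

from tests.v1.determinism.utils import BACKENDS  # assumed import path


@pytest.mark.parametrize("backend", BACKENDS)
def test_runs_for_every_backend(backend: str, monkeypatch: pytest.MonkeyPatch) -> None:
    # Select the attention backend the same way the determinism tests do, so
    # "TRITON_MLA" is exercised alongside "FLASH_ATTN" and the FlashInfer backends.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
    ...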
vllm/attention/layer.py (36 additions, 0 deletions)
@@ -25,6 +25,7 @@
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+ from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
UnquantizedLinearMethod,
@@ -246,6 +247,24 @@ def __init__(
else:
self.attn_backend = attn_backend

+ # prefix caching + batch invariance is currently not supported for
+ # FLASHINFER and TRITON_MLA.
+ if (
+     cache_config is not None
+     and cache_config.enable_prefix_caching
+     and vllm_is_batch_invariant()
+     and (
+         self.attn_backend.get_name() == "FLASHINFER"
+         or self.attn_backend.get_name() == "TRITON_MLA"
+     )
+ ):
+     logger.warning_once(
+         "Disabling prefix caching for FLASHINFER/TRITON_MLA "
+         "with batch invariance, as it is not yet supported.",
+         scope="local",
+     )
+     cache_config.enable_prefix_caching = False

impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(
num_heads,
@@ -623,6 +642,23 @@ def __init__(
use_mla=True,
use_sparse=use_sparse,
)

+ if (
+     cache_config is not None
+     and cache_config.enable_prefix_caching
+     and vllm_is_batch_invariant()
+     and (
+         self.attn_backend.get_name() == "TRITON_MLA"
+         or self.attn_backend.get_name() == "FLASHINFER"
+     )
+ ):
+     logger.warning_once(
+         "Disabling prefix caching for TRITON_MLA / FLASHINFER "
+         "with batch invariance, as it is not yet supported.",
+         scope="local",
+     )
+     cache_config.enable_prefix_caching = False

impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls(
num_heads=self.num_heads,
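The guard added above appears twice, once in the standard attention constructor and once in the MLA attention constructor, with the backend names merely listed in a different order. A standalone sketch of the shared check; the helper name and signature are illustrative and not part of vLLM's API:

# Backends that cannot yet combine prefix caching with batch-invariant execution.
_PREFIX_CACHING_UNSUPPORTED = {"FLASHINFER", "TRITON_MLA"}


def maybe_disable_prefix_caching(cache_config, backend_name: str, batch_invariant: bool) -> None:
    """Disable prefix caching in place when the backend cannot support it under batch invariance."""
    if (
        cache_config is not None
        and cache_config.enable_prefix_caching
        and batch_invariant
        and backend_name in _PREFIX_CACHING_UNSUPPORTED
    ):
        # Mutating the shared CacheConfig is what the code above does as well, so every
        # later consumer of the config observes prefix caching as disabled.
        cache_config.enable_prefix_caching = False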
vllm/model_executor/layers/batch_invariant.py (1 addition, 1 deletion)
@@ -1006,11 +1006,11 @@ def override_envs_for_invariance():
"FLASH_ATTN", # best supported backend
"FLASHINFER",
"FLASH_ATTN_MLA",
"TRITON_MLA",
# Not yet supported MLA backends
# "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
# "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
# "TRITON_MLA",
]
if curr_attn_backend not in supported_backends:
error = (
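With "TRITON_MLA" moved from the commented-out list into supported_backends, override_envs_for_invariance() no longer rejects it. A brief usage sketch; the environment-variable names are assumptions based on the surrounding code rather than something this diff confirms, and the model name is illustrative:

import os

# Assumed flags: VLLM_BATCH_INVARIANT enables batch-invariant kernels and
# VLLM_ATTENTION_BACKEND selects the backend checked against supported_backends.
os.environ["VLLM_BATCH_INVARIANT"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_MLA"  # accepted after this change

from vllm import LLM

# Any MLA-style model would route to an MLA attention backend; this name is illustrative.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", dtype="bfloat16", max_model_len=8192)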