diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py
index 4311547baccf..fc953a66f082 100644
--- a/tests/v1/determinism/test_batch_invariance.py
+++ b/tests/v1/determinism/test_batch_invariance.py
@@ -185,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enable_prefix_caching=False,
+        # enable_prefix_caching=False,
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",  # not everything is supported
@@ -393,7 +393,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
         gpu_memory_utilization=0.9,
         max_model_len=2048,
         dtype="bfloat16",
-        enable_prefix_caching=False,
     )
 
     prompt = "the capital of france is"
@@ -457,7 +456,6 @@ def test_logprobs_without_batch_invariance_should_fail(
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enable_prefix_caching=False,
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
@@ -681,7 +679,6 @@ def test_decode_logprobs_match_prefill_logprobs(
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enable_prefix_caching=False,
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
@@ -928,7 +925,6 @@ def LLM_with_max_seqs(
         max_model_len=max_model_len,
         dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-        enable_prefix_caching=False,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )
diff --git a/tests/v1/determinism/test_online_batch_invariance.py b/tests/v1/determinism/test_online_batch_invariance.py
index d74b435797f8..5e3b99736494 100644
--- a/tests/v1/determinism/test_online_batch_invariance.py
+++ b/tests/v1/determinism/test_online_batch_invariance.py
@@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     }
 
     tp_size = os.getenv("VLLM_TP_SIZE", "1")
-    server_args: list[str] = []
+    server_args: list[str] = [
+        "--max-model-len=8192",
+        "--max-num-seqs=32",
+    ]
     if tp_size:
         server_args += ["-tp", tp_size]
 
diff --git a/tests/v1/determinism/utils.py b/tests/v1/determinism/utils.py
index 0d7da107728b..6aab50cf84ab 100644
--- a/tests/v1/determinism/utils.py
+++ b/tests/v1/determinism/utils.py
@@ -17,6 +17,7 @@
 
 BACKENDS: list[str] = [
     "FLASH_ATTN",
+    "TRITON_MLA",
 ]
 
 if has_flashinfer():
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index da5a62617129..f1bbe2d9e4da 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -25,6 +25,7 @@
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     UnquantizedLinearMethod,
@@ -246,6 +247,24 @@ def __init__(
         else:
             self.attn_backend = attn_backend
 
+        # prefix caching + batch invariance is currently not supported for
+        # FLASHINFER and TRITON_MLA.
+        if (
+            cache_config is not None
+            and cache_config.enable_prefix_caching
+            and vllm_is_batch_invariant()
+            and (
+                self.attn_backend.get_name() == "FLASHINFER"
+                or self.attn_backend.get_name() == "TRITON_MLA"
+            )
+        ):
+            logger.warning_once(
+                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
+                "with batch invariance, as it is not yet supported.",
+                scope="local",
+            )
+            cache_config.enable_prefix_caching = False
+
         impl_cls = self.attn_backend.get_impl_cls()
         self.impl = impl_cls(
             num_heads,
@@ -623,6 +642,23 @@ def __init__(
             use_mla=True,
             use_sparse=use_sparse,
         )
+
+        if (
+            cache_config is not None
+            and cache_config.enable_prefix_caching
+            and vllm_is_batch_invariant()
+            and (
+                self.attn_backend.get_name() == "TRITON_MLA"
+                or self.attn_backend.get_name() == "FLASHINFER"
+            )
+        ):
+            logger.warning_once(
+                "Disabling prefix caching for TRITON_MLA / FLASHINFER "
+                "with batch invariance, as it is not yet supported.",
+                scope="local",
+            )
+            cache_config.enable_prefix_caching = False
+
         impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
         self.impl = impl_cls(
             num_heads=self.num_heads,
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 4154122636dc..4cab47f4192a 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -1006,11 +1006,11 @@ def override_envs_for_invariance():
         "FLASH_ATTN",  # best supported backend
         "FLASHINFER",
         "FLASH_ATTN_MLA",
+        "TRITON_MLA",
         # Not yet supported MLA backends
         # "FLASHMLA",
         # "FLEX_ATTENTION",  # IMA issue even if we disable batch invariance
         # "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
-        # "TRITON_MLA",
     ]
     if curr_attn_backend not in supported_backends:
         error = (