diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py
index 99c68a863f59..3987986f1786 100644
--- a/vllm/attention/backends/placeholder_attn.py
+++ b/vllm/attention/backends/placeholder_attn.py
@@ -75,11 +75,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
     # Maximum query length in the batch.
     max_query_len: Optional[int]
 
-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int]
+    # Max number of query tokens among request in the batch.
+    max_decode_query_len: Optional[int]
 
     # Maximum sequence length among prefill batch. 0 if there are decoding
     # requests only.
@@ -140,7 +137,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
             slot_mapping=slot_mapping,
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
-            decode_query_len=0,
+            max_decode_query_len=0,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
             max_decode_seq_len=0,
@@ -172,7 +169,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
             slot_mapping=slot_mapping,
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
-            decode_query_len=self.decode_query_len,
+            max_decode_query_len=self.max_decode_query_len,
             max_query_len=None,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
@@ -256,9 +253,9 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         max_query_len = max(query_lens)
         decode_query_lens = query_lens[self.num_prefills:]
         if len(decode_query_lens) > 0:
-            decode_query_len = max(decode_query_lens)
+            max_decode_query_len = max(decode_query_lens)
         else:
-            decode_query_len = 1
+            max_decode_query_len = 1
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
@@ -304,7 +301,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
-            decode_query_len=decode_query_len,
+            max_decode_query_len=max_decode_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
             query_start_loc=query_start_loc,
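
For reference, the renamed field holds the maximum query length among the decode requests of a batch, falling back to 1 when the batch contains no decode requests (the ordinary single-token decode case). The sketch below is not part of the patch; it is a standalone illustration of the computation performed in the `build()` method shown in the last two hunks, with the hypothetical helper name `compute_max_decode_query_len` chosen for clarity.

```python
from typing import List


def compute_max_decode_query_len(query_lens: List[int],
                                 num_prefills: int) -> int:
    """Illustrative helper mirroring the build() logic after the rename.

    query_lens lists prefill requests first, then decode requests, so the
    decode slice starts at num_prefills.
    """
    decode_query_lens = query_lens[num_prefills:]
    if len(decode_query_lens) > 0:
        # With speculative decoding a decode request can carry more than
        # one query token, so the maximum may exceed 1.
        return max(decode_query_lens)
    # No decode requests in the batch: default to 1.
    return 1


# Example: two prefill requests (5 and 7 query tokens) followed by two
# decode requests (1 and 3 query tokens, e.g. with speculative decoding).
print(compute_max_decode_query_len([5, 7, 1, 3], num_prefills=2))  # -> 3
```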