diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 561c3cf39e9d..e0a096a9106e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -476,67 +476,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             self.device, non_blocking=True).long()
 
         # Prepare for cascade attention if needed.
-        common_prefix_len = (scheduler_output.num_common_prefix_blocks *
-                             self.block_size)
-        if common_prefix_len == 0:
-            # Common case.
-            use_cascade = False
-        else:
-            # NOTE(woosuk): Cascade attention uses two attention kernels: one
-            # for the common prefix and the other for the rest. For the first
-            # kernel, we concatenate all the query tokens (possibly from
-            # different requests) and treat them as if they are from the same
-            # request. Then, we use bi-directional attention to process the
-            # common prefix in the KV cache. Importantly, this means that the
-            # first kernel does not do any masking.
-
-            # Consider the following example:
-            # Request 1's input query: [D, E, X]
-            # Request 1's kv cache: [A, B, C, D, E, X]
-            # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
-            # Request 2's input query: [E, Y]
-            # Request 2's kv cache: [A, B, C, D, E, Y]
-            # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
-
-            # If we use [A, B, C, D, E] as the common prefix, then the
-            # first kernel will compute the bi-directional attention between
-            # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
-            # However, this is wrong because D in Request 1 should not attend to
-            # E in the common prefix (i.e., we need masking).
-            # To avoid this, [A, B, C, D] should be the common prefix.
-            # That is, the common prefix should be capped by the minimum
-            # num_computed_tokens among the requests, and plus one to include
-            # the first token of the query.
-
-            # In practice, we use [A, B, C] as the common prefix, instead of
-            # [A, B, C, D] (i.e., the common prefix is capped by the minimum
-            # num_computed_tokens, without plus one).
-            # This is because of an implementation detail: We want to always
-            # use two kernels for cascade attention. Let's imagine:
-            # Request 3's input query: [D]
-            # Request 3's kv cache: [A, B, C, D]
-            # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
-            # If we use [A, B, C, D] as the common prefix for Request 1-3,
-            # then Request 3 will be processed only by the first kernel,
-            # and the second kernel will get an empty input. While this is not
-            # a fundamental problem, our current implementation does not support
-            # this case.
-            common_prefix_len = min(
-                common_prefix_len,
-                self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
-            # common_prefix_len should be a multiple of the block size.
-            common_prefix_len = (common_prefix_len // self.block_size *
-                                 self.block_size)
-            use_cascade = FlashAttentionBackend.use_cascade_attention(
-                common_prefix_len=common_prefix_len,
-                query_lens=num_scheduled_tokens,
-                num_query_heads=self.num_query_heads,
-                num_kv_heads=self.num_kv_heads,
-                use_alibi=False,  # FIXME
-                use_sliding_window=self.sliding_window is not None,
-                num_sms=self.num_sms,
-            )
-
+        common_prefix_len = self._compute_cascade_attn_prefix_len(
+            num_scheduled_tokens,
+            scheduler_output.num_common_prefix_blocks,
+        )
+        use_cascade = common_prefix_len > 0
         if use_cascade:
             # TODO: Optimize.
             cu_prefix_query_lens = torch.tensor(
@@ -581,6 +525,90 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices
 
+    def _compute_cascade_attn_prefix_len(
+        self,
+        num_scheduled_tokens: np.ndarray,
+        num_common_prefix_blocks: int,
+    ) -> int:
+        """Compute the length of the common prefix for cascade attention.
+
+        NOTE(woosuk): The common prefix length returned by this function
+        represents the length used specifically for cascade attention, not the
+        actual number of tokens shared between requests. When cascade attention
+        is disabled (use_cascade=False), this function returns 0 even if
+        requests share common tokens. Additionally, the common prefix length is
+        truncated to a multiple of the block size and may be further truncated
+        due to implementation details explained below.
+
+        Args:
+            num_scheduled_tokens: Number of tokens scheduled per request.
+            num_common_prefix_blocks: Number of shared KV cache blocks.
+
+        Returns:
+            int: Length of common prefix in tokens.
+        """
+        common_prefix_len = num_common_prefix_blocks * self.block_size
+        if common_prefix_len == 0:
+            # Common case.
+            return 0
+
+        # NOTE(woosuk): Cascade attention uses two attention kernels: one
+        # for the common prefix and the other for the rest. For the first
+        # kernel, we concatenate all the query tokens (possibly from
+        # different requests) and treat them as if they are from the same
+        # request. Then, we use bi-directional attention to process the
+        # common prefix in the KV cache. Importantly, this means that the
+        # first kernel does not do any masking.
+
+        # Consider the following example:
+        # Request 1's input query: [D, E, X]
+        # Request 1's kv cache: [A, B, C, D, E, X]
+        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
+        # Request 2's input query: [E, Y]
+        # Request 2's kv cache: [A, B, C, D, E, Y]
+        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
+
+        # If we use [A, B, C, D, E] as the common prefix, then the
+        # first kernel will compute the bi-directional attention between
+        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
+        # However, this is wrong because D in Request 1 should not attend to
+        # E in the common prefix (i.e., we need masking).
+        # To avoid this, [A, B, C, D] should be the common prefix.
+        # That is, the common prefix should be capped by the minimum
+        # num_computed_tokens among the requests, and plus one to include
+        # the first token of the query.
+
+        # In practice, we use [A, B, C] as the common prefix, instead of
+        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
+        # num_computed_tokens, without plus one).
+        # This is because of an implementation detail: We want to always
+        # use two kernels for cascade attention. Let's imagine:
+        # Request 3's input query: [D]
+        # Request 3's kv cache: [A, B, C, D]
+        # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
+        # If we use [A, B, C, D] as the common prefix for Request 1-3,
+        # then Request 3 will be processed only by the first kernel,
+        # and the second kernel will get an empty input. While this is not
+        # a fundamental problem, our current implementation does not support
+        # this case.
+        num_reqs = len(num_scheduled_tokens)
+        common_prefix_len = min(
+            common_prefix_len,
+            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
+        # common_prefix_len should be a multiple of the block size.
+        common_prefix_len = (common_prefix_len // self.block_size *
+                             self.block_size)
+        use_cascade = FlashAttentionBackend.use_cascade_attention(
+            common_prefix_len=common_prefix_len,
+            query_lens=num_scheduled_tokens,
+            num_query_heads=self.num_query_heads,
+            num_kv_heads=self.num_kv_heads,
+            use_alibi=False,  # FIXME
+            use_sliding_window=self.sliding_window is not None,
+            num_sms=self.num_sms,
+        )
+        return common_prefix_len if use_cascade else 0
+
     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
         mrope_pos_ptr = 0
         num_reqs = self.input_batch.num_reqs
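
For reference, a minimal standalone sketch of the capping rule the NOTE above describes: the shared prefix is capped by the minimum num_computed_tokens across the batch and then rounded down to a multiple of the block size, so the second (per-request) kernel never receives an empty input. The block size, block count, and per-request token counts below are illustrative assumptions, not values from this diff.

import numpy as np

# Hypothetical inputs, for illustration only.
block_size = 16
num_common_prefix_blocks = 8                     # 8 * 16 = 128 shared tokens in the KV cache
num_computed_tokens = np.array([100, 115, 121])  # assumed per-request computed-token counts

common_prefix_len = num_common_prefix_blocks * block_size                   # 128
common_prefix_len = min(common_prefix_len, int(num_computed_tokens.min()))  # capped to 100
common_prefix_len = common_prefix_len // block_size * block_size            # rounded down to 96
print(common_prefix_len)  # 96 tokens are handled by the shared-prefix (first) kernel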