150 changes: 89 additions & 61 deletions vllm/v1/worker/gpu_model_runner.py
@@ -476,67 +476,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
self.device, non_blocking=True).long()

# Prepare for cascade attention if needed.
common_prefix_len = (scheduler_output.num_common_prefix_blocks *
self.block_size)
if common_prefix_len == 0:
# Common case.
use_cascade = False
else:
# NOTE(woosuk): Cascade attention uses two attention kernels: one
# for the common prefix and the other for the rest. For the first
# kernel, we concatenate all the query tokens (possibly from
# different requests) and treat them as if they are from the same
# request. Then, we use bi-directional attention to process the
# common prefix in the KV cache. Importantly, this means that the
# first kernel does not do any masking.

# Consider the following example:
# Request 1's input query: [D, E, X]
# Request 1's kv cache: [A, B, C, D, E, X]
# Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
# Request 2's input query: [E, Y]
# Request 2's kv cache: [A, B, C, D, E, Y]
# Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

# If we use [A, B, C, D, E] as the common prefix, then the
# first kernel will compute the bi-directional attention between
# input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
# However, this is wrong because D in Request 1 should not attend to
# E in the common prefix (i.e., we need masking).
# To avoid this, [A, B, C, D] should be the common prefix.
# That is, the common prefix should be capped by the minimum
# num_computed_tokens among the requests, and plus one to include
# the first token of the query.

# In practice, we use [A, B, C] as the common prefix, instead of
# [A, B, C, D] (i.e., the common prefix is capped by the minimum
# num_computed_tokens, without plus one).
# This is because of an implementation detail: We want to always
# use two kernels for cascade attention. Let's imagine:
# Request 3's input query: [D]
# Request 3's kv cache: [A, B, C, D]
# Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
# If we use [A, B, C, D] as the common prefix for Request 1-3,
# then Request 3 will be processed only by the first kernel,
# and the second kernel will get an empty input. While this is not
# a fundamental problem, our current implementation does not support
# this case.
common_prefix_len = min(
common_prefix_len,
self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
# common_prefix_len should be a multiple of the block size.
common_prefix_len = (common_prefix_len // self.block_size *
self.block_size)
use_cascade = FlashAttentionBackend.use_cascade_attention(
common_prefix_len=common_prefix_len,
query_lens=num_scheduled_tokens,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
use_alibi=False, # FIXME
use_sliding_window=self.sliding_window is not None,
num_sms=self.num_sms,
)

common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.num_common_prefix_blocks,
)
use_cascade = common_prefix_len > 0
if use_cascade:
# TODO: Optimize.
cu_prefix_query_lens = torch.tensor(
@@ -581,6 +525,90 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices

def _compute_cascade_attn_prefix_len(
self,
num_scheduled_tokens: np.ndarray,
num_common_prefix_blocks: int,
) -> int:
"""Compute the length of the common prefix for cascade attention.

NOTE(woosuk): The common prefix length returned by this function
represents the length used specifically for cascade attention, not the
actual number of tokens shared between requests. When cascade attention
is disabled (use_cascade=False), this function returns 0 even if
requests share common tokens. Additionally, the common prefix length is
truncated to a multiple of the block size and may be further truncated
due to implementation details explained below.

Args:
num_scheduled_tokens: Number of tokens scheduled per request.
num_common_prefix_blocks: Number of shared KV cache blocks.

Returns:
int: Length of common prefix in tokens.
"""
common_prefix_len = num_common_prefix_blocks * self.block_size
if common_prefix_len == 0:
# Common case.
return 0

# NOTE(woosuk): Cascade attention uses two attention kernels: one
# for the common prefix and the other for the rest. For the first
# kernel, we concatenate all the query tokens (possibly from
# different requests) and treat them as if they are from the same
# request. Then, we use bi-directional attention to process the
# common prefix in the KV cache. Importantly, this means that the
# first kernel does not do any masking.

# Consider the following example:
# Request 1's input query: [D, E, X]
# Request 1's kv cache: [A, B, C, D, E, X]
# Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
# Request 2's input query: [E, Y]
# Request 2's kv cache: [A, B, C, D, E, Y]
# Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

# If we use [A, B, C, D, E] as the common prefix, then the
# first kernel will compute the bi-directional attention between
# input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
# However, this is wrong because D in Request 1 should not attend to
# E in the common prefix (i.e., we need masking).
# To avoid this, [A, B, C, D] should be the common prefix.
# That is, the common prefix should be capped by the minimum
# num_computed_tokens among the requests, and plus one to include
# the first token of the query.

# In practice, we use [A, B, C] as the common prefix, instead of
# [A, B, C, D] (i.e., the common prefix is capped by the minimum
# num_computed_tokens, without plus one).
# This is because of an implementation detail: We want to always
# use two kernels for cascade attention. Let's imagine:
# Request 3's input query: [D]
# Request 3's kv cache: [A, B, C, D]
# Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
# If we use [A, B, C, D] as the common prefix for Request 1-3,
# then Request 3 will be processed only by the first kernel,
# and the second kernel will get an empty input. While this is not
# a fundamental problem, our current implementation does not support
# this case.
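# As a concrete illustration (hypothetical numbers, not taken from this
# change): with block_size = 2 and Requests 1 and 2 above,
# min(num_computed_tokens) = 3, so the capped prefix is
# 3 // 2 * 2 = 2 tokens (i.e., [A, B]); the trailing partial block is
# dropped so the prefix stays block-aligned.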
num_reqs = len(num_scheduled_tokens)
common_prefix_len = min(
common_prefix_len,
self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
# common_prefix_len should be a multiple of the block size.
common_prefix_len = (common_prefix_len // self.block_size *
self.block_size)
use_cascade = FlashAttentionBackend.use_cascade_attention(
common_prefix_len=common_prefix_len,
query_lens=num_scheduled_tokens,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
use_alibi=False, # FIXME
use_sliding_window=self.sliding_window is not None,
num_sms=self.num_sms,
)
return common_prefix_len if use_cascade else 0
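
For context, the truncation rules described in the comments above can be sketched as a standalone helper. The following is a minimal illustration only: the function name capped_common_prefix_len and the input numbers are made up, and it omits the FlashAttentionBackend.use_cascade_attention heuristic that finally gates use_cascade in the real method.

# Minimal sketch of the prefix-length capping described above.
# Hypothetical standalone function; the actual logic lives in
# GPUModelRunner._compute_cascade_attn_prefix_len and additionally
# consults FlashAttentionBackend.use_cascade_attention.
import numpy as np

def capped_common_prefix_len(
    num_common_prefix_blocks: int,
    num_computed_tokens: np.ndarray,  # per-request computed token counts
    block_size: int,
) -> int:
    common_prefix_len = num_common_prefix_blocks * block_size
    if common_prefix_len == 0:
        return 0
    # Cap at the minimum num_computed_tokens so the unmasked first kernel
    # never lets a query token attend to a prefix position it must not see.
    common_prefix_len = min(common_prefix_len, int(num_computed_tokens.min()))
    # Round down to a multiple of the block size.
    return common_prefix_len // block_size * block_size

# Requests 1 and 2 from the comment above, with a hypothetical block_size=2:
# num_computed_tokens = [3, 4] -> min is 3 -> rounded down to 2 tokens.
print(capped_common_prefix_len(2, np.array([3, 4]), 2))  # 2

Returning 0 here effectively disables cascade attention, since the call site in _prepare_inputs now simply checks common_prefix_len > 0.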

def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0
num_reqs = self.input_batch.num_reqs