@@ -476,67 +476,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             self.device, non_blocking=True).long()
 
         # Prepare for cascade attention if needed.
-        common_prefix_len = (scheduler_output.num_common_prefix_blocks *
-                             self.block_size)
-        if common_prefix_len == 0:
-            # Common case.
-            use_cascade = False
-        else:
-            # NOTE(woosuk): Cascade attention uses two attention kernels: one
-            # for the common prefix and the other for the rest. For the first
-            # kernel, we concatenate all the query tokens (possibly from
-            # different requests) and treat them as if they are from the same
-            # request. Then, we use bi-directional attention to process the
-            # common prefix in the KV cache. Importantly, this means that the
-            # first kernel does not do any masking.
-
-            # Consider the following example:
-            # Request 1's input query: [D, E, X]
-            # Request 1's kv cache: [A, B, C, D, E, X]
-            # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
-            # Request 2's input query: [E, Y]
-            # Request 2's kv cache: [A, B, C, D, E, Y]
-            # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
-
-            # If we use [A, B, C, D, E] as the common prefix, then the
-            # first kernel will compute the bi-directional attention between
-            # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
-            # However, this is wrong because D in Request 1 should not attend to
-            # E in the common prefix (i.e., we need masking).
-            # To avoid this, [A, B, C, D] should be the common prefix.
-            # That is, the common prefix should be capped by the minimum
-            # num_computed_tokens among the requests, and plus one to include
-            # the first token of the query.
-
-            # In practice, we use [A, B, C] as the common prefix, instead of
-            # [A, B, C, D] (i.e., the common prefix is capped by the minimum
-            # num_computed_tokens, without plus one).
-            # This is because of an implementation detail: We want to always
-            # use two kernels for cascade attention. Let's imagine:
-            # Request 3's input query: [D]
-            # Request 3's kv cache: [A, B, C, D]
-            # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
-            # If we use [A, B, C, D] as the common prefix for Request 1-3,
-            # then Request 3 will be processed only by the first kernel,
-            # and the second kernel will get an empty input. While this is not
-            # a fundamental problem, our current implementation does not support
-            # this case.
-            common_prefix_len = min(
-                common_prefix_len,
-                self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
-            # common_prefix_len should be a multiple of the block size.
-            common_prefix_len = (common_prefix_len // self.block_size *
-                                 self.block_size)
-            use_cascade = FlashAttentionBackend.use_cascade_attention(
-                common_prefix_len=common_prefix_len,
-                query_lens=num_scheduled_tokens,
-                num_query_heads=self.num_query_heads,
-                num_kv_heads=self.num_kv_heads,
-                use_alibi=False,  # FIXME
-                use_sliding_window=self.sliding_window is not None,
-                num_sms=self.num_sms,
-            )
-
+        common_prefix_len = self._compute_cascade_attn_prefix_len(
+            num_scheduled_tokens,
+            scheduler_output.num_common_prefix_blocks,
+        )
+        use_cascade = common_prefix_len > 0
         if use_cascade:
             # TODO: Optimize.
             cu_prefix_query_lens = torch.tensor(
@@ -581,6 +525,90 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices
 
+    def _compute_cascade_attn_prefix_len(
+        self,
+        num_scheduled_tokens: np.ndarray,
+        num_common_prefix_blocks: int,
+    ) -> int:
+        """Compute the length of the common prefix for cascade attention.
+
+        NOTE(woosuk): The common prefix length returned by this function
+        represents the length used specifically for cascade attention, not the
+        actual number of tokens shared between requests. When cascade attention
+        is disabled (use_cascade=False), this function returns 0 even if
+        requests share common tokens. Additionally, the common prefix length is
+        truncated to a multiple of the block size and may be further truncated
+        due to implementation details explained below.
+
+        Args:
+            num_scheduled_tokens: Number of tokens scheduled per request.
+            num_common_prefix_blocks: Number of shared KV cache blocks.
+
+        Returns:
+            int: Length of common prefix in tokens.
+        """
+        common_prefix_len = num_common_prefix_blocks * self.block_size
+        if common_prefix_len == 0:
+            # Common case.
+            return 0
+
+        # NOTE(woosuk): Cascade attention uses two attention kernels: one
+        # for the common prefix and the other for the rest. For the first
+        # kernel, we concatenate all the query tokens (possibly from
+        # different requests) and treat them as if they are from the same
+        # request. Then, we use bi-directional attention to process the
+        # common prefix in the KV cache. Importantly, this means that the
+        # first kernel does not do any masking.
+
+        # Consider the following example:
+        # Request 1's input query: [D, E, X]
+        # Request 1's kv cache: [A, B, C, D, E, X]
+        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
+        # Request 2's input query: [E, Y]
+        # Request 2's kv cache: [A, B, C, D, E, Y]
+        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
+
+        # If we use [A, B, C, D, E] as the common prefix, then the
+        # first kernel will compute the bi-directional attention between
+        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
+        # However, this is wrong because D in Request 1 should not attend to
+        # E in the common prefix (i.e., we need masking).
+        # To avoid this, [A, B, C, D] should be the common prefix.
+        # That is, the common prefix should be capped by the minimum
+        # num_computed_tokens among the requests, and plus one to include
+        # the first token of the query.
+
+        # In practice, we use [A, B, C] as the common prefix, instead of
+        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
+        # num_computed_tokens, without plus one).
+        # This is because of an implementation detail: We want to always
+        # use two kernels for cascade attention. Let's imagine:
+        # Request 3's input query: [D]
+        # Request 3's kv cache: [A, B, C, D]
+        # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
+        # If we use [A, B, C, D] as the common prefix for Request 1-3,
+        # then Request 3 will be processed only by the first kernel,
+        # and the second kernel will get an empty input. While this is not
+        # a fundamental problem, our current implementation does not support
+        # this case.
+        num_reqs = len(num_scheduled_tokens)
+        common_prefix_len = min(
+            common_prefix_len,
+            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
+        # common_prefix_len should be a multiple of the block size.
+        common_prefix_len = (common_prefix_len // self.block_size *
+                             self.block_size)
+        use_cascade = FlashAttentionBackend.use_cascade_attention(
+            common_prefix_len=common_prefix_len,
+            query_lens=num_scheduled_tokens,
+            num_query_heads=self.num_query_heads,
+            num_kv_heads=self.num_kv_heads,
+            use_alibi=False,  # FIXME
+            use_sliding_window=self.sliding_window is not None,
+            num_sms=self.num_sms,
+        )
+        return common_prefix_len if use_cascade else 0
+
     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
         mrope_pos_ptr = 0
         num_reqs = self.input_batch.num_reqs
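For reference, the capping rules described in the NOTE above boil down to two steps: clamp the block-level shared prefix to the smallest num_computed_tokens in the batch, then round the result down to a block boundary. The snippet below is a minimal standalone sketch of that arithmetic (not part of this diff), using the Request 1/2 example from the comment; the block size and per-request values are invented for illustration, and the final FlashAttentionBackend.use_cascade_attention profitability check is omitted.

```python
import numpy as np

block_size = 2
num_common_prefix_blocks = 2            # full blocks shared by both requests: [A, B], [C, D]
num_computed_tokens = np.array([3, 4])  # Request 1: [A, B, C]; Request 2: [A, B, C, D]

# Block-level shared prefix: 2 blocks * 2 tokens = 4 tokens ([A, B, C, D]).
common_prefix_len = num_common_prefix_blocks * block_size

# Cap by the minimum num_computed_tokens so the unmasked first kernel never
# lets a query token attend beyond its own position (4 -> 3, i.e. [A, B, C]).
common_prefix_len = min(common_prefix_len, int(num_computed_tokens.min()))

# Truncate to a multiple of the block size (3 -> 2, i.e. [A, B]).
common_prefix_len = common_prefix_len // block_size * block_size

print(common_prefix_len)  # 2
```

With these toy numbers the shared-prefix kernel would cover [A, B] for every request, and the second, masked kernel would handle everything after it.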