From 91f51aa08a63b05b439155536bbe6c63a5120c66 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 27 Feb 2025 16:41:40 +0000 Subject: [PATCH 01/24] flashinfer minimally working --- vllm/platforms/cuda.py | 6 +- vllm/v1/attention/backends/flashinfer.py | 543 +++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 41 +- 3 files changed, 576 insertions(+), 14 deletions(-) create mode 100755 vllm/v1/attention/backends/flashinfer.py diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index bf425b89132e..29ba311030ed 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -154,8 +154,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla) -> str: if use_v1: - logger.info("Using Flash Attention backend on V1 engine.") - return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + # logger.info("Using Flash Attention backend on V1 engine.") + # return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + logger.info("Using FlashInfer backend on V1 engine.") + return "vllm.v1.attention.backends.flashinfer.FlashInferBackend" if use_mla: logger.info("Using Triton MLA backend.") return "vllm.attention.backends.triton_mla.TritonMLABackend" diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py new file mode 100755 index 000000000000..48b968d698ac --- /dev/null +++ b/vllm/v1/attention/backends/flashinfer.py @@ -0,0 +1,543 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with FlashAttention.""" +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +try: + from flashinfer import BatchPrefillWithPagedKVCacheWrapper + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +except ImportError: + # Avoid turning these types into variables during type checking + if not TYPE_CHECKING: + BatchPrefillWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + +import numpy as np +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionState, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.ops.triton_merge_attn_states import merge_attn_states +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import cdiv + +if current_platform.is_cuda(): + from vllm.vllm_flash_attn import flash_attn_varlen_func + +logger = init_logger(__name__) + +FLASHINFER_WORKSPACE_BUFFER = None +FLASHINFER_WRAPPER = None + + +class FlashInferBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 128, 256] + + @staticmethod + def get_name() -> str: + return "FLASHINFER_VLLM_V1" + + @staticmethod + def get_impl_cls() -> Type["FlashInferImpl"]: + return FlashInferImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashInferMetadata + + @staticmethod + def get_state_cls() -> Type["FlashInferState"]: + return FlashInferState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + return use_cascade_attention(*args, **kwargs) + + + 
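As a point of reference for the layout above, here is a minimal standalone sketch (not part of the patch; all sizes below are made-up example values) of the KV cache that get_kv_cache_shape() describes: each block holds both K and V, and the per-block (block_size, num_kv_heads, head_size) order matches the "NHD" layout the FlashInfer wrappers are constructed with later in this file.

import torch

# Example values only; the real ones come from the model and cache config.
num_blocks, block_size, num_kv_heads, head_size = 128, 16, 8, 128

# Shape returned by FlashInferBackend.get_kv_cache_shape():
# (num_blocks, 2, block_size, num_kv_heads, head_size)
kv_cache = torch.empty(num_blocks, 2, block_size, num_kv_heads, head_size,
                       dtype=torch.float16)

# K and V are views into the same allocation; reshape_and_cache_flash later
# writes keys into kv_cache[:, 0] and values into kv_cache[:, 1].
key_cache, value_cache = kv_cache[:, 0], kv_cache[:, 1]
assert key_cache.shape == (num_blocks, block_size, num_kv_heads, head_size)

# Per-block footprint in bytes (K and V together).
bytes_per_block = kv_cache[0].numel() * kv_cache.element_size()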
+@dataclass +class PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters. + """ + + window_left: int + logits_soft_cap: Optional[float] + sm_scale: float + + +def get_per_layer_parameters( + vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]: + """ + Scan all attention layers and determine some hyperparameters + to use during `plan`. + """ + + layers = vllm_config.compilation_config.static_forward_context + per_layer_params: Dict[str, PerLayerParameters] = {} + + for key, layer in layers.items(): + assert isinstance(layer, Attention) + + impl = layer.impl + assert isinstance(impl, FlashInferImpl) + + # Infer hyperparameters from the attention layer + window_size = impl.sliding_window + window_left = window_size[0] if window_size is not None else -1 + logits_soft_cap = impl.logits_soft_cap + sm_scale = impl.scale + + per_layer_params[key] = PerLayerParameters(window_left, + logits_soft_cap, sm_scale) + + return per_layer_params + + +def infer_global_hyperparameters( + per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters: + - `window_left` + - `logits_soft_cap` + - `sm_scale` + + So this function asserts that all layers share the same values for these + hyperparameters and returns the global values. + """ + + assert len(per_layer_params) > 0, "No attention layers found in the model." + + param_sets = list(per_layer_params.values()) + global_params = param_sets[0] + for params in param_sets: + assert params == global_params, ( + "FlashInfer backend currently only supports models in which all " + "layers share the same values for the following hyperparameters: " + "`window_left`, `logits_soft_cap`, `sm_scale`.") + + return global_params + + +class FlashInferState(AttentionState): + + def __init__(self, runner): + self.runner = runner + self._workspace_buffer = None + self._wrapper = None + + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = get_current_vllm_config() + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + return self._workspace_buffer + + def _get_wrapper(self): + if self._wrapper is None: + self._wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), "NHD") + return self._wrapper + + def begin_forward(self, attn_metadata: "FlashInferMetadata"): + attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + attn_metadata.decode_wrapper = self._get_decode_wrapper() + attn_metadata.begin_forward() + + +@dataclass +class FlashInferMetadata: + + num_actual_tokens: int # Number of tokens excluding padding. + + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. 
+ qo_indptr: torch.Tensor + # The number of query/output heads + num_qo_heads: int + # The number of key/value heads + num_kv_heads: int + # The dimension of the attention heads + head_dim: int + # Block size of vllm + page_size: int + + wrapper: BatchPrefillWithPagedKVCacheWrapper = None + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + + # The data type of the paged kv cache + data_type: torch.dtype = None + # The data type of the query + q_data_type: torch.dtype = None + # FlashInfer 0.2 encourages passing host tensors + device: torch.device = torch.device("cpu") + + # The FlashInfer backend currently supports only models in which all layers + # share the same following hyperparameters: + + # The left (inclusive) window size for the attention window, when + # set to `-1`, the window size will be set to the full length of + # the sequence. Defaults to `-1`. + window_left: int = -1 + # The attention logits soft capping value (used in Gemini, Grok and + # Gemma-2, etc.), if not provided, will be set to `0`. If greater + # than 0, the logits will be capped according to formula: + # $$\texttt{logits\_soft\_cap} \times + # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$, + # where $x$ is the input logits. + logits_soft_cap: Optional[float] = None + # The scale used in softmax, if not provided, will be set to + # `1.0 / sqrt(head_dim)`. + sm_scale: Optional[float] = None + + def __post_init__(self): + # Refer to + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f" received {self.head_dim}.") + + def plan(self): + global FLASHINFER_WORKSPACE_BUFFER, FLASHINFER_WRAPPER + block_table_bounds = ( + (self.seq_lens + self.page_size - 1) // self.page_size) + + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + mask = (torch.arange(self.block_table.size(1), + dtype=self.block_table.dtype, + device=self.block_table.device).unsqueeze(0) + < block_table_bounds.unsqueeze(1)) + paged_kv_indices = self.block_table[mask] + + # paged_kv_indptr is used to index into paged_kv_indices: [0, 3, 6, 8] + # Shape: [batch_size + 1] + paged_kv_indptr = torch.cat([ + torch.zeros(1, dtype=block_table_bounds.dtype, + device=block_table_bounds.device), + block_table_bounds.cumsum(dim=0, dtype=torch.int32)]) + + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len = self.seq_lens % self.page_size + paged_kv_last_page_len = torch.where( + paged_kv_last_page_len == 0, self.page_size, paged_kv_last_page_len) + + if FLASHINFER_WORKSPACE_BUFFER is None: + FLASHINFER_WORKSPACE_BUFFER = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device="cuda") + FLASHINFER_WRAPPER = BatchPrefillWithPagedKVCacheWrapper( + FLASHINFER_WORKSPACE_BUFFER, "NHD") + + FLASHINFER_WRAPPER.plan( + self.qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + #sm_scale=self.sm_scale, 
+ #window_left=self.window_left, + #logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.data_type) + + +class FlashInferImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = FlashInferBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashInfer. " + f"Supported head sizes are: {support_head_sizes}.") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashInfer. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache = [num_blocks, 2, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run. + return output + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + # Reshape the input keys and values and store them in the cache. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] and + # value[:num_actual_tokens] because the reshape_and_cache_flash op uses + # the slot_mapping's shape to determine the number of actual tokens. 
+ torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + output = FLASHINFER_WRAPPER.run( + query, + kv_cache, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output, + ) + + return output + + +def use_cascade_attention( + common_prefix_len: int, + query_lens: np.ndarray, + num_query_heads: int, + num_kv_heads: int, + use_alibi: bool, + use_sliding_window: bool, + num_sms: int, +) -> bool: + """Decide whether to use cascade attention. + + This function 1) checks whether cascade attention is supported with the + given configuration, and 2) heuristically decides whether using cascade + attention can improve performance. + """ + # Too short common prefix. Probably not worth using cascade attention. + # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold. + # NOTE(woosuk): This is the common case. We should return False as soon as + # possible to avoid any unnecessary computation. + if common_prefix_len < 256: + return False + # Cascade attention is currently not supported with these variants. + if use_alibi or use_sliding_window: + return False + # Too few queries. Probably not worth using cascade attention. + # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold. + num_reqs = len(query_lens) + if num_reqs < 8: + return False + + # Heuristics to decide whether using cascade attention is beneficial. + # 1. When FlashDecoding is not used for normal attention, cascade attention + # is likely to be faster since it saves memory bandwidth. + num_queries_per_kv = num_query_heads // num_kv_heads + # The criteria for using FlashDecoding can be found in the following link: + # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 + use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window + and not use_alibi and np.all(query_lens == 1)) + if not use_flash_decoding: + # Use cascade attention. + return True + + # 2. When FlashDecoding is used for normal attention, it is not clear + # whether cascade attention is beneficial, because FlashDecoding can + # launch more CTAs than cascade attention. + # We use a simple performance model to compare the two methods. + # NOTE(woosuk): The performance model is very rough and may not be + # accurate. + num_tokens = num_reqs + # NOTE(woosuk): These are default tile sizes. flash-attn might use + # different tile sizes (e.g., 64 or 256) depending on the configuration. + q_tile_size = 128 + kv_tile_size = 128 + num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size) + + cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size) + cascade_waves = cdiv(cascade_ctas, num_sms) + cascade_time = cascade_waves * num_prefix_tiles + + flash_decoding_ctas = (num_reqs * num_kv_heads * + cdiv(num_queries_per_kv, q_tile_size)) + flash_decoding_ctas *= num_prefix_tiles + flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) + + # Use cascade attention if it is faster than FlashDecoding. 
+ return cascade_time < flash_decoding_time + + +def cascade_attention( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + cu_query_lens: torch.Tensor, + max_query_len: int, + cu_prefix_query_lens: torch.Tensor, + prefix_kv_lens: torch.Tensor, + suffix_kv_lens: torch.Tensor, + max_kv_len: int, + softmax_scale: float, + alibi_slopes: Optional[torch.Tensor], + sliding_window: Tuple[int, int], + logits_soft_cap: float, + block_table: torch.Tensor, + common_prefix_len: int, + fa_version: int, +) -> torch.Tensor: + assert alibi_slopes is None, ("Cascade attention does not support ALiBi.") + # TODO: Support sliding window. + assert sliding_window == (-1, -1), ( + "Cascade attention does not support sliding window.") + + num_tokens = query.shape[0] + block_size = key_cache.shape[-3] + assert common_prefix_len % block_size == 0 + num_common_kv_blocks = common_prefix_len // block_size + assert num_common_kv_blocks > 0 + + # Process shared prefix. + prefix_output, prefix_lse = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_prefix_query_lens, + seqused_k=prefix_kv_lens, + max_seqlen_q=num_tokens, + max_seqlen_k=common_prefix_len, + softmax_scale=softmax_scale, + causal=False, + window_size=sliding_window, + block_table=block_table[:1], + softcap=logits_soft_cap, + return_softmax_lse=True, + fa_version=fa_version, + ) + + # Process suffix per query. + suffix_output, suffix_lse = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + seqused_k=suffix_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len - common_prefix_len, + softmax_scale=softmax_scale, + causal=True, + window_size=sliding_window, + block_table=block_table[:, num_common_kv_blocks:], + softcap=logits_soft_cap, + return_softmax_lse=True, + fa_version=fa_version, + ) + + # Merge prefix and suffix outputs, and store the result in output. 
+ merge_attn_states(output, prefix_output, prefix_lse, suffix_output, + suffix_lse) \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4d0ae9a205a1..c08b2062aa17 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -26,6 +26,8 @@ LayerBlockType, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) +from vllm.v1.attention.backends.flashinfer import (FlashInferBackend, + FlashInferMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.engine.mm_input_cache import MMInputCacheClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -427,7 +429,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> Tuple[FlashAttentionMetadata, torch.Tensor]: + ) -> Tuple[FlashInferMetadata, torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -559,21 +561,36 @@ def _prepare_inputs( prefix_kv_lens = None suffix_kv_lens = None - attn_metadata = FlashAttentionMetadata( + # attn_metadata = FlashAttentionMetadata( + # num_actual_tokens=total_num_scheduled_tokens, + # max_query_len=max_num_scheduled_tokens, + # query_start_loc=query_start_loc, + # max_seq_len=max_seq_len, + # seq_lens=seq_lens, + # block_table=( + # self.input_batch.block_table.get_device_tensor()[:num_reqs]), + # slot_mapping=slot_mapping, + # use_cascade=use_cascade, + # common_prefix_len=common_prefix_len, + # cu_prefix_query_lens=cu_prefix_query_lens, + # prefix_kv_lens=prefix_kv_lens, + # suffix_kv_lens=suffix_kv_lens, + # ) + attn_metadata = FlashInferMetadata( num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - query_start_loc=query_start_loc, - max_seq_len=max_seq_len, seq_lens=seq_lens, block_table=( self.input_batch.block_table.get_device_tensor()[:num_reqs]), slot_mapping=slot_mapping, - use_cascade=use_cascade, - common_prefix_len=common_prefix_len, - cu_prefix_query_lens=cu_prefix_query_lens, - prefix_kv_lens=prefix_kv_lens, - suffix_kv_lens=suffix_kv_lens, + qo_indptr=query_start_loc, + num_qo_heads=self.model_config.get_num_attention_heads(self.parallel_config), + num_kv_heads=self.model_config.get_num_kv_heads(self.parallel_config), + head_dim=self.model_config.get_head_size(), + page_size=self.cache_config.block_size, + data_type=self.kv_cache_dtype, + q_data_type=self.model_config.dtype, ) + attn_metadata.plan() use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -667,7 +684,7 @@ def _compute_cascade_attn_prefix_len( # common_prefix_len should be a multiple of the block size. 
common_prefix_len = (common_prefix_len // self.block_size * self.block_size) - use_cascade = FlashAttentionBackend.use_cascade_attention( + use_cascade = FlashInferBackend.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, num_query_heads=self.num_query_heads, @@ -1379,7 +1396,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: assert tensor_config.size % layer_spec.page_size_bytes == 0 num_blocks = tensor_config.size // layer_spec.page_size_bytes if isinstance(layer_spec, FullAttentionSpec): - kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + kv_cache_shape = FlashInferBackend.get_kv_cache_shape( num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, layer_spec.head_size) dtype = layer_spec.dtype From 0df828e9b50d9463083e1fb1b0cb9ac8971a8a28 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 28 Feb 2025 00:13:01 +0000 Subject: [PATCH 02/24] unbreak flash_attn --- vllm/platforms/cuda.py | 9 +-- vllm/platforms/interface.py | 1 + vllm/v1/attention/backends/flashinfer.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 79 ++++++++++++++---------- 4 files changed, 53 insertions(+), 38 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 29ba311030ed..41f93cf9d470 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -154,10 +154,11 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla) -> str: if use_v1: - # logger.info("Using Flash Attention backend on V1 engine.") - # return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" - logger.info("Using FlashInfer backend on V1 engine.") - return "vllm.v1.attention.backends.flashinfer.FlashInferBackend" + if selected_backend == _Backend.FLASHINFER_VLLM_V1: + logger.info("Using FlashInfer backend on V1 engine.") + return "vllm.v1.attention.backends.flashinfer.FlashInferBackend" + logger.info("Using Flash Attention backend on V1 engine.") + return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" if use_mla: logger.info("Using Triton MLA backend.") return "vllm.attention.backends.triton_mla.TritonMLABackend" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d6dae2e526dc..ca8439ad49dc 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -34,6 +34,7 @@ class _Backend(enum.Enum): TORCH_SDPA = enum.auto() OPENVINO = enum.auto() FLASHINFER = enum.auto() + FLASHINFER_VLLM_V1 = enum.auto() TRITON_MLA = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 48b968d698ac..54d8ebd329c4 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -141,7 +141,7 @@ def infer_global_hyperparameters( return global_params -class FlashInferState(AttentionState): +class FlashInferState: def __init__(self, runner): self.runner = runner diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c08b2062aa17..42871bb6f455 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -9,6 +9,7 @@ import torch.distributed import torch.nn as nn +from vllm.attention import get_attn_backend from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig @@ -235,6 +236,15 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np 
= self.seq_lens_cpu.numpy() + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, + ) + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. @@ -429,7 +439,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> Tuple[FlashInferMetadata, torch.Tensor]: + ) -> Tuple[FlashAttentionMetadata, torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -561,36 +571,39 @@ def _prepare_inputs( prefix_kv_lens = None suffix_kv_lens = None - # attn_metadata = FlashAttentionMetadata( - # num_actual_tokens=total_num_scheduled_tokens, - # max_query_len=max_num_scheduled_tokens, - # query_start_loc=query_start_loc, - # max_seq_len=max_seq_len, - # seq_lens=seq_lens, - # block_table=( - # self.input_batch.block_table.get_device_tensor()[:num_reqs]), - # slot_mapping=slot_mapping, - # use_cascade=use_cascade, - # common_prefix_len=common_prefix_len, - # cu_prefix_query_lens=cu_prefix_query_lens, - # prefix_kv_lens=prefix_kv_lens, - # suffix_kv_lens=suffix_kv_lens, - # ) - attn_metadata = FlashInferMetadata( - num_actual_tokens=total_num_scheduled_tokens, - seq_lens=seq_lens, - block_table=( - self.input_batch.block_table.get_device_tensor()[:num_reqs]), - slot_mapping=slot_mapping, - qo_indptr=query_start_loc, - num_qo_heads=self.model_config.get_num_attention_heads(self.parallel_config), - num_kv_heads=self.model_config.get_num_kv_heads(self.parallel_config), - head_dim=self.model_config.get_head_size(), - page_size=self.cache_config.block_size, - data_type=self.kv_cache_dtype, - q_data_type=self.model_config.dtype, - ) - attn_metadata.plan() + attn_metadata_cls = self.attn_backend.get_metadata_cls() + if attn_metadata_cls is FlashAttentionMetadata: + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=( + self.input_batch.block_table.get_device_tensor()[:num_reqs]), + slot_mapping=slot_mapping, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + ) + elif attn_metadata_cls is FlashInferMetadata: + attn_metadata = FlashInferMetadata( + num_actual_tokens=total_num_scheduled_tokens, + seq_lens=seq_lens, + block_table=( + self.input_batch.block_table.get_device_tensor()[:num_reqs]), + slot_mapping=slot_mapping, + qo_indptr=query_start_loc, + num_qo_heads=self.model_config.get_num_attention_heads(self.parallel_config), + num_kv_heads=self.model_config.get_num_kv_heads(self.parallel_config), + head_dim=self.model_config.get_head_size(), + page_size=self.cache_config.block_size, + data_type=self.kv_cache_dtype, + q_data_type=self.model_config.dtype, + ) + attn_metadata.plan() use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -684,7 +697,7 @@ def _compute_cascade_attn_prefix_len( # common_prefix_len should be a multiple of the block size. 
common_prefix_len = (common_prefix_len // self.block_size * self.block_size) - use_cascade = FlashInferBackend.use_cascade_attention( + use_cascade = self.attn_backend.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, num_query_heads=self.num_query_heads, @@ -1396,7 +1409,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: assert tensor_config.size % layer_spec.page_size_bytes == 0 num_blocks = tensor_config.size // layer_spec.page_size_bytes if isinstance(layer_spec, FullAttentionSpec): - kv_cache_shape = FlashInferBackend.get_kv_cache_shape( + kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, layer_spec.head_size) dtype = layer_spec.dtype From 0f1eb93ed2482aceee80380f3147f29fa1d46334 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 28 Feb 2025 03:18:30 +0000 Subject: [PATCH 03/24] make backend instantiable --- vllm/v1/attention/backends/flash_attn.py | 5 +- vllm/v1/attention/backends/flashinfer.py | 184 +++++++++++------------ vllm/v1/worker/gpu_model_runner.py | 5 +- 3 files changed, 97 insertions(+), 97 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 1922a3bf2724..4f302b2b7f93 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -20,10 +20,13 @@ logger = init_logger(__name__) -class FlashAttentionBackend(AttentionBackend): +class FlashAttentionBackend: accept_output_buffer: bool = True + def __init__(self, runner): + self.runner = runner + @staticmethod def get_supported_head_sizes() -> List[int]: return [32, 64, 96, 128, 160, 192, 224, 256] diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 54d8ebd329c4..d4f93c0f7775 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -4,12 +4,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type try: - from flashinfer import BatchPrefillWithPagedKVCacheWrapper + from flashinfer import (BatchPrefillWithPagedKVCacheWrapper, + MultiLevelCascadeAttentionWrapper) FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 except ImportError: # Avoid turning these types into variables during type checking if not TYPE_CHECKING: BatchPrefillWithPagedKVCacheWrapper = None + MultiLevelCascadeAttentionWrapper = None FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 import numpy as np @@ -31,14 +33,86 @@ logger = init_logger(__name__) -FLASHINFER_WORKSPACE_BUFFER = None -FLASHINFER_WRAPPER = None - -class FlashInferBackend(AttentionBackend): +class FlashInferBackend: accept_output_buffer: bool = True + def __init__(self, runner): + self.runner = runner + self._workspace_buffer = None + self._wrapper = None + + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = get_current_vllm_config() + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + return self._workspace_buffer + + def _get_wrapper(self): + if self._wrapper is None: + self._wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), "NHD") + return self._wrapper + + def begin_forward(self, attn_metadata: "FlashInferMetadata"): + if self.global_hyperparameters is None: + 
self.global_hyperparameters = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + + block_table_bounds = ( + (attn_metadata.seq_lens + attn_metadata.page_size - 1) // attn_metadata.page_size) + + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + mask = (torch.arange(attn_metadata.block_table.size(1), + dtype=attn_metadata.block_table.dtype, + device=attn_metadata.block_table.device).unsqueeze(0) + < block_table_bounds.unsqueeze(1)) + paged_kv_indices = attn_metadata.block_table[mask] + + # paged_kv_indptr is used to index into paged_kv_indices: [0, 3, 6, 8] + # Shape: [batch_size + 1] + paged_kv_indptr = torch.cat([ + torch.zeros(1, dtype=block_table_bounds.dtype, + device=block_table_bounds.device), + block_table_bounds.cumsum(dim=0, dtype=torch.int32)]) + + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len = attn_metadata.seq_lens % attn_metadata.page_size + paged_kv_last_page_len = torch.where( + paged_kv_last_page_len == 0, attn_metadata.page_size, paged_kv_last_page_len) + + attn_metadata.wrapper = self._get_wrapper() + attn_metadata.wrapper.plan( + attn_metadata.qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + causal=True, + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters.logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.data_type) + @staticmethod def get_supported_head_sizes() -> List[int]: return [64, 128, 256] @@ -74,7 +148,6 @@ def use_cascade_attention(*args, **kwargs) -> bool: return use_cascade_attention(*args, **kwargs) - @dataclass class PerLayerParameters: """ @@ -141,38 +214,6 @@ def infer_global_hyperparameters( return global_params -class FlashInferState: - - def __init__(self, runner): - self.runner = runner - self._workspace_buffer = None - self._wrapper = None - - # Global hyperparameters shared by all attention layers - self.global_hyperparameters: Optional[PerLayerParameters] = None - - self.vllm_config = get_current_vllm_config() - - def _get_workspace_buffer(self): - if self._workspace_buffer is None: - self._workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.runner.device) - return self._workspace_buffer - - def _get_wrapper(self): - if self._wrapper is None: - self._wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), "NHD") - return self._wrapper - - def begin_forward(self, attn_metadata: "FlashInferMetadata"): - attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - attn_metadata.decode_wrapper = self._get_decode_wrapper() - attn_metadata.begin_forward() - - @dataclass class FlashInferMetadata: @@ -235,61 +276,6 @@ def __post_init__(self): f"Only {supported_head_sizes} are supported for head_dim,", f" received {self.head_dim}.") - def plan(self): - global FLASHINFER_WORKSPACE_BUFFER, FLASHINFER_WRAPPER - block_table_bounds = ( - (self.seq_lens + self.page_size - 1) // self.page_size) - - # An example for paged_kv_indices, paged_kv_indptr: - # request 1, page indices [0, 5, 8] - # 
request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - mask = (torch.arange(self.block_table.size(1), - dtype=self.block_table.dtype, - device=self.block_table.device).unsqueeze(0) - < block_table_bounds.unsqueeze(1)) - paged_kv_indices = self.block_table[mask] - - # paged_kv_indptr is used to index into paged_kv_indices: [0, 3, 6, 8] - # Shape: [batch_size + 1] - paged_kv_indptr = torch.cat([ - torch.zeros(1, dtype=block_table_bounds.dtype, - device=block_table_bounds.device), - block_table_bounds.cumsum(dim=0, dtype=torch.int32)]) - - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] - paged_kv_last_page_len = self.seq_lens % self.page_size - paged_kv_last_page_len = torch.where( - paged_kv_last_page_len == 0, self.page_size, paged_kv_last_page_len) - - if FLASHINFER_WORKSPACE_BUFFER is None: - FLASHINFER_WORKSPACE_BUFFER = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device="cuda") - FLASHINFER_WRAPPER = BatchPrefillWithPagedKVCacheWrapper( - FLASHINFER_WORKSPACE_BUFFER, "NHD") - - FLASHINFER_WRAPPER.plan( - self.qo_indptr, - paged_kv_indptr, - paged_kv_indices, - paged_kv_last_page_len, - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - causal=True, - #sm_scale=self.sm_scale, - #window_left=self.window_left, - #logits_soft_cap=self.logits_soft_cap, - q_data_type=self.q_data_type, - kv_data_type=self.data_type) - class FlashInferImpl(AttentionImpl): @@ -393,7 +379,17 @@ def forward( layer._v_scale, ) - output = FLASHINFER_WRAPPER.run( + window_left = (self.sliding_window[0] + if self.sliding_window is not None else -1) + + assert attn_metadata.wrapper is not None + assert attn_metadata.wrapper._causal + assert attn_metadata.wrapper._window_left == window_left + assert attn_metadata.wrapper._logits_soft_cap == ( + self.logits_soft_cap or 0.0) + assert attn_metadata.wrapper._sm_scale == self.scale + + output = attn_metadata.wrapper.run( query, kv_cache, k_scale=layer._k_scale_float, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 42871bb6f455..c82b5cd511da 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -236,7 +236,7 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - self.attn_backend = get_attn_backend( + attn_backend_cls = get_attn_backend( self.model_config.get_head_size(), self.model_config.dtype, self.kv_cache_dtype, @@ -244,6 +244,7 @@ def __init__( self.model_config.is_attention_free, use_mla=self.model_config.use_mla, ) + self.attn_backend = attn_backend_cls(self) def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler @@ -603,7 +604,7 @@ def _prepare_inputs( data_type=self.kv_cache_dtype, q_data_type=self.model_config.dtype, ) - attn_metadata.plan() + self.attn_backend.begin_forward(attn_metadata) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 From ff0e3630b674364b641b975474df0ec807b6b6cd Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 28 Feb 2025 16:52:42 +0000 Subject: [PATCH 04/24] cascade attention minimally working --- vllm/v1/attention/backends/flashinfer.py | 173 ++++++++++++++++++----- vllm/v1/worker/gpu_model_runner.py | 5 + 2 files changed, 144 insertions(+), 34 deletions(-) diff --git 
a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index d4f93c0f7775..67f2dc5b22bd 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -42,6 +42,7 @@ def __init__(self, runner): self.runner = runner self._workspace_buffer = None self._wrapper = None + self._cascade_wrapper = None # Global hyperparameters shared by all attention layers self.global_hyperparameters: Optional[PerLayerParameters] = None @@ -62,11 +63,13 @@ def _get_wrapper(self): self._get_workspace_buffer(), "NHD") return self._wrapper - def begin_forward(self, attn_metadata: "FlashInferMetadata"): - if self.global_hyperparameters is None: - self.global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(self.vllm_config)) + def _get_cascade_wrapper(self): + if self._cascade_wrapper is None: + self._cascade_wrapper = MultiLevelCascadeAttentionWrapper( + 2, self._get_workspace_buffer(), "NHD") + return self._cascade_wrapper + def _get_normal_attn_args(self, attn_metadata: "FlashInferMetadata"): block_table_bounds = ( (attn_metadata.seq_lens + attn_metadata.page_size - 1) // attn_metadata.page_size) @@ -95,13 +98,74 @@ def begin_forward(self, attn_metadata: "FlashInferMetadata"): paged_kv_last_page_len = attn_metadata.seq_lens % attn_metadata.page_size paged_kv_last_page_len = torch.where( paged_kv_last_page_len == 0, attn_metadata.page_size, paged_kv_last_page_len) + + return attn_metadata.qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len + + def _get_cascade_attn_args(self, attn_metadata: "FlashInferMetadata"): + seq_lens = attn_metadata.seq_lens + page_size = attn_metadata.page_size + block_table = attn_metadata.block_table + + num_common_kv_blocks = attn_metadata.common_prefix_len // page_size + shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks], + dtype=torch.int32, + device=block_table.device) + shared_kv_page_indices = block_table[0, :num_common_kv_blocks] + shared_kv_last_page_len = torch.tensor([0], dtype=torch.int32, + device=block_table.device) + block_table = block_table[:, num_common_kv_blocks:] + + qo_indptr_arr = [attn_metadata.cu_prefix_query_lens, + attn_metadata.qo_indptr] + + block_table_bounds = (seq_lens + page_size - 1) // page_size + + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + mask = (torch.arange(block_table.size(1), dtype=block_table.dtype, + device=block_table.device).unsqueeze(0) + < block_table_bounds.unsqueeze(1)) + paged_kv_indices = block_table[mask] + + # paged_kv_indptr is used to index into paged_kv_indices: [0, 3, 6, 8] + # Shape: [batch_size + 1] + paged_kv_indptr = torch.cat([ + torch.zeros(1, dtype=block_table_bounds.dtype, + device=block_table_bounds.device), + block_table_bounds.cumsum(dim=0, dtype=torch.int32)]) + + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len = seq_lens % page_size + paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, + page_size, paged_kv_last_page_len) + + return ( + qo_indptr_arr, + [shared_kv_page_indptr, paged_kv_indptr], + [shared_kv_page_indices, paged_kv_indices], + [shared_kv_last_page_len, paged_kv_last_page_len], + ) + + def begin_forward(self, attn_metadata: "FlashInferMetadata"): + if 
self.global_hyperparameters is None: + self.global_hyperparameters = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + + if not attn_metadata.use_cascade: + attn_metadata.wrapper = self._get_wrapper() + attn_args = self._get_normal_attn_args(attn_metadata) + else: + attn_metadata.wrapper = self._get_cascade_wrapper() + attn_args = self._get_cascade_attn_args(attn_metadata) - attn_metadata.wrapper = self._get_wrapper() attn_metadata.wrapper.plan( - attn_metadata.qo_indptr, - paged_kv_indptr, - paged_kv_indices, - paged_kv_last_page_len, + *attn_args, attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim, @@ -111,7 +175,8 @@ def begin_forward(self, attn_metadata: "FlashInferMetadata"): window_left=self.global_hyperparameters.window_left, logits_soft_cap=self.global_hyperparameters.logits_soft_cap, q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.data_type) + #kv_data_type=attn_metadata.data_type, + ) @staticmethod def get_supported_head_sizes() -> List[int]: @@ -144,7 +209,6 @@ def get_kv_cache_shape( @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: - return False return use_cascade_attention(*args, **kwargs) @@ -236,6 +300,13 @@ class FlashInferMetadata: # Block size of vllm page_size: int + # For cascade attention. + use_cascade: bool + common_prefix_len: int + cu_prefix_query_lens: Optional[torch.Tensor] + prefix_kv_lens: Optional[torch.Tensor] + suffix_kv_lens: Optional[torch.Tensor] + wrapper: BatchPrefillWithPagedKVCacheWrapper = None # For logging. @@ -251,20 +322,20 @@ class FlashInferMetadata: # The FlashInfer backend currently supports only models in which all layers # share the same following hyperparameters: - # The left (inclusive) window size for the attention window, when - # set to `-1`, the window size will be set to the full length of - # the sequence. Defaults to `-1`. - window_left: int = -1 - # The attention logits soft capping value (used in Gemini, Grok and - # Gemma-2, etc.), if not provided, will be set to `0`. If greater - # than 0, the logits will be capped according to formula: - # $$\texttt{logits\_soft\_cap} \times - # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$, - # where $x$ is the input logits. - logits_soft_cap: Optional[float] = None - # The scale used in softmax, if not provided, will be set to - # `1.0 / sqrt(head_dim)`. - sm_scale: Optional[float] = None + # # The left (inclusive) window size for the attention window, when + # # set to `-1`, the window size will be set to the full length of + # # the sequence. Defaults to `-1`. + # window_left: int = -1 + # # The attention logits soft capping value (used in Gemini, Grok and + # # Gemma-2, etc.), if not provided, will be set to `0`. If greater + # # than 0, the logits will be capped according to formula: + # # $$\texttt{logits\_soft\_cap} \times + # # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$, + # # where $x$ is the input logits. + # logits_soft_cap: Optional[float] = None + # # The scale used in softmax, if not provided, will be set to + # # `1.0 / sqrt(head_dim)`. 
+ # sm_scale: Optional[float] = None def __post_init__(self): # Refer to @@ -383,18 +454,52 @@ def forward( if self.sliding_window is not None else -1) assert attn_metadata.wrapper is not None - assert attn_metadata.wrapper._causal - assert attn_metadata.wrapper._window_left == window_left - assert attn_metadata.wrapper._logits_soft_cap == ( - self.logits_soft_cap or 0.0) - assert attn_metadata.wrapper._sm_scale == self.scale + #assert attn_metadata.wrapper._causal + #assert attn_metadata.wrapper._window_left == window_left + #assert attn_metadata.wrapper._logits_soft_cap == ( + # self.logits_soft_cap or 0.0) + #assert attn_metadata.wrapper._sm_scale == self.scale + + if not attn_metadata.use_cascade: + # Regular attention (common case). + output = attn_metadata.wrapper.run( + query, + kv_cache, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output, + ) + return output - output = attn_metadata.wrapper.run( + # Cascade attention (rare case). + print("CASCADE") + out = attn_metadata.wrapper.run( query, kv_cache, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=output, + #k_scale=layer._k_scale_float, + #v_scale=layer._v_scale_float, + #out=output, + ) + output.copy_(out) + return output + cascade_attention( + output[:attn_metadata.num_actual_tokens], + query[:attn_metadata.num_actual_tokens], + key_cache, + value_cache, + cu_query_lens=attn_metadata.query_start_loc, + max_query_len=attn_metadata.max_query_len, + cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, + prefix_kv_lens=attn_metadata.prefix_kv_lens, + suffix_kv_lens=attn_metadata.suffix_kv_lens, + max_kv_len=attn_metadata.max_seq_len, + softmax_scale=self.scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window, + logits_soft_cap=self.logits_soft_cap, + block_table=attn_metadata.block_table, + common_prefix_len=attn_metadata.common_prefix_len, + fa_version=self.vllm_flash_attn_version, ) return output diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c82b5cd511da..d324c376743c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -603,6 +603,11 @@ def _prepare_inputs( page_size=self.cache_config.block_size, data_type=self.kv_cache_dtype, q_data_type=self.model_config.dtype, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, ) self.attn_backend.begin_forward(attn_metadata) From 1e1e0ae2a933bb546a4002a70cfd1c76926e9136 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 28 Feb 2025 17:26:28 +0000 Subject: [PATCH 05/24] cleanup --- vllm/v1/attention/backends/flashinfer.py | 291 ++++++----------------- 1 file changed, 73 insertions(+), 218 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 67f2dc5b22bd..5a2166d0461d 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -17,19 +17,13 @@ import numpy as np import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, - AttentionState, AttentionType) +from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer, + AttentionMetadata, AttentionType) from vllm.attention.layer import Attention -from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.config import VllmConfig, 
get_current_vllm_config from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.utils import cdiv -if current_platform.is_cuda(): - from vllm.vllm_flash_attn import flash_attn_varlen_func logger = init_logger(__name__) @@ -41,8 +35,8 @@ class FlashInferBackend: def __init__(self, runner): self.runner = runner self._workspace_buffer = None - self._wrapper = None - self._cascade_wrapper = None + self._prefill_wrapper = None # Wrapper for prefill/append + self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers self.global_hyperparameters: Optional[PerLayerParameters] = None @@ -57,11 +51,11 @@ def _get_workspace_buffer(self): device=self.runner.device) return self._workspace_buffer - def _get_wrapper(self): - if self._wrapper is None: - self._wrapper = BatchPrefillWithPagedKVCacheWrapper( + def _get_prefill_wrapper(self): + if self._prefill_wrapper is None: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( self._get_workspace_buffer(), "NHD") - return self._wrapper + return self._prefill_wrapper def _get_cascade_wrapper(self): if self._cascade_wrapper is None: @@ -69,54 +63,22 @@ def _get_cascade_wrapper(self): 2, self._get_workspace_buffer(), "NHD") return self._cascade_wrapper - def _get_normal_attn_args(self, attn_metadata: "FlashInferMetadata"): - block_table_bounds = ( - (attn_metadata.seq_lens + attn_metadata.page_size - 1) // attn_metadata.page_size) - - # An example for paged_kv_indices, paged_kv_indptr: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - mask = (torch.arange(attn_metadata.block_table.size(1), - dtype=attn_metadata.block_table.dtype, - device=attn_metadata.block_table.device).unsqueeze(0) - < block_table_bounds.unsqueeze(1)) - paged_kv_indices = attn_metadata.block_table[mask] - - # paged_kv_indptr is used to index into paged_kv_indices: [0, 3, 6, 8] - # Shape: [batch_size + 1] - paged_kv_indptr = torch.cat([ - torch.zeros(1, dtype=block_table_bounds.dtype, - device=block_table_bounds.device), - block_table_bounds.cumsum(dim=0, dtype=torch.int32)]) - - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] - paged_kv_last_page_len = attn_metadata.seq_lens % attn_metadata.page_size - paged_kv_last_page_len = torch.where( - paged_kv_last_page_len == 0, attn_metadata.page_size, paged_kv_last_page_len) - - return attn_metadata.qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len - - def _get_cascade_attn_args(self, attn_metadata: "FlashInferMetadata"): + def begin_forward(self, attn_metadata: "FlashInferMetadata"): seq_lens = attn_metadata.seq_lens page_size = attn_metadata.page_size block_table = attn_metadata.block_table - num_common_kv_blocks = attn_metadata.common_prefix_len // page_size - shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks], - dtype=torch.int32, - device=block_table.device) - shared_kv_page_indices = block_table[0, :num_common_kv_blocks] - shared_kv_last_page_len = torch.tensor([0], dtype=torch.int32, - device=block_table.device) - block_table = block_table[:, num_common_kv_blocks:] - - qo_indptr_arr = [attn_metadata.cu_prefix_query_lens, - attn_metadata.qo_indptr] + if attn_metadata.use_cascade: + # Grab the blocks of the shared prefix from the first request. 
+ num_common_kv_blocks = attn_metadata.common_prefix_len // page_size + shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks], + dtype=torch.int32, + device=block_table.device) + shared_kv_page_indices = block_table[0, :num_common_kv_blocks] + shared_kv_last_page_len = torch.tensor([0], dtype=torch.int32, + device=block_table.device) + # Remove the blocks of the shared prefix from all requests. + block_table = block_table[:, num_common_kv_blocks:] block_table_bounds = (seq_lens + page_size - 1) // page_size @@ -145,38 +107,46 @@ def _get_cascade_attn_args(self, attn_metadata: "FlashInferMetadata"): paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len) - return ( - qo_indptr_arr, - [shared_kv_page_indptr, paged_kv_indptr], - [shared_kv_page_indices, paged_kv_indices], - [shared_kv_last_page_len, paged_kv_last_page_len], - ) - - def begin_forward(self, attn_metadata: "FlashInferMetadata"): if self.global_hyperparameters is None: self.global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(self.vllm_config)) - if not attn_metadata.use_cascade: - attn_metadata.wrapper = self._get_wrapper() - attn_args = self._get_normal_attn_args(attn_metadata) + if attn_metadata.use_cascade: + attn_metadata.cascade_wrapper = self._get_cascade_wrapper() + attn_metadata.cascade_wrapper.plan( + [attn_metadata.cu_prefix_query_lens, attn_metadata.qo_indptr], + [shared_kv_page_indptr, paged_kv_indptr], + [shared_kv_page_indices, paged_kv_indices], + [shared_kv_last_page_len, paged_kv_last_page_len], + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + causal=True, + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters.logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + #kv_data_type=attn_metadata.data_type, + ) else: - attn_metadata.wrapper = self._get_cascade_wrapper() - attn_args = self._get_cascade_attn_args(attn_metadata) - - attn_metadata.wrapper.plan( - *attn_args, - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, - causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters.logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - #kv_data_type=attn_metadata.data_type, - ) + attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + attn_metadata.prefill_wrapper.plan( + attn_metadata.qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + causal=True, + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters.logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.data_type, + ) @staticmethod def get_supported_head_sizes() -> List[int]: @@ -194,10 +164,6 @@ def get_impl_cls() -> Type["FlashInferImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return FlashInferMetadata - @staticmethod - def get_state_cls() -> Type["FlashInferState"]: - return FlashInferState - @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -207,8 +173,12 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: return (num_blocks, 2, block_size, num_kv_heads, 
head_size) - @staticmethod - def use_cascade_attention(*args, **kwargs) -> bool: + #@staticmethod + def use_cascade_attention(self, *args, **kwargs) -> bool: + if self.runner.kv_cache_dtype != self.runner.model_config.dtype: + # TODO: The cascade wrapper currently does not support setting + # kv cache dtype to something different from query dtype. + return False return use_cascade_attention(*args, **kwargs) @@ -307,7 +277,8 @@ class FlashInferMetadata: prefix_kv_lens: Optional[torch.Tensor] suffix_kv_lens: Optional[torch.Tensor] - wrapper: BatchPrefillWithPagedKVCacheWrapper = None + prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper = None + cascade_wrapper: MultiLevelCascadeAttentionWrapper = None # For logging. num_input_tokens: int = 0 # Number of tokens including padding. @@ -319,24 +290,6 @@ class FlashInferMetadata: # FlashInfer 0.2 encourages passing host tensors device: torch.device = torch.device("cpu") - # The FlashInfer backend currently supports only models in which all layers - # share the same following hyperparameters: - - # # The left (inclusive) window size for the attention window, when - # # set to `-1`, the window size will be set to the full length of - # # the sequence. Defaults to `-1`. - # window_left: int = -1 - # # The attention logits soft capping value (used in Gemini, Grok and - # # Gemma-2, etc.), if not provided, will be set to `0`. If greater - # # than 0, the logits will be capped according to formula: - # # $$\texttt{logits\_soft\_cap} \times - # # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$, - # # where $x$ is the input logits. - # logits_soft_cap: Optional[float] = None - # # The scale used in softmax, if not provided, will be set to - # # `1.0 / sqrt(head_dim)`. - # sm_scale: Optional[float] = None - def __post_init__(self): # Refer to # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 @@ -453,16 +406,17 @@ def forward( window_left = (self.sliding_window[0] if self.sliding_window is not None else -1) - assert attn_metadata.wrapper is not None - #assert attn_metadata.wrapper._causal - #assert attn_metadata.wrapper._window_left == window_left - #assert attn_metadata.wrapper._logits_soft_cap == ( - # self.logits_soft_cap or 0.0) - #assert attn_metadata.wrapper._sm_scale == self.scale + if not attn_metadata.use_cascade: # Regular attention (common case). - output = attn_metadata.wrapper.run( + assert attn_metadata.prefill_wrapper is not None + assert attn_metadata.prefill_wrapper._causal + assert attn_metadata.prefill_wrapper._window_left == window_left + assert attn_metadata.prefill_wrapper._logits_soft_cap == ( + self.logits_soft_cap or 0.0) + assert attn_metadata.prefill_wrapper._sm_scale == self.scale + output = attn_metadata.prefill_wrapper.run( query, kv_cache, k_scale=layer._k_scale_float, @@ -472,36 +426,8 @@ def forward( return output # Cascade attention (rare case). 
- print("CASCADE") - out = attn_metadata.wrapper.run( - query, - kv_cache, - #k_scale=layer._k_scale_float, - #v_scale=layer._v_scale_float, - #out=output, - ) - output.copy_(out) - return output - cascade_attention( - output[:attn_metadata.num_actual_tokens], - query[:attn_metadata.num_actual_tokens], - key_cache, - value_cache, - cu_query_lens=attn_metadata.query_start_loc, - max_query_len=attn_metadata.max_query_len, - cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, - prefix_kv_lens=attn_metadata.prefix_kv_lens, - suffix_kv_lens=attn_metadata.suffix_kv_lens, - max_kv_len=attn_metadata.max_seq_len, - softmax_scale=self.scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window, - logits_soft_cap=self.logits_soft_cap, - block_table=attn_metadata.block_table, - common_prefix_len=attn_metadata.common_prefix_len, - fa_version=self.vllm_flash_attn_version, - ) - + assert attn_metadata.cascade_wrapper is not None + output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) return output @@ -571,74 +497,3 @@ def use_cascade_attention( # Use cascade attention if it is faster than FlashDecoding. return cascade_time < flash_decoding_time - - -def cascade_attention( - output: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - cu_query_lens: torch.Tensor, - max_query_len: int, - cu_prefix_query_lens: torch.Tensor, - prefix_kv_lens: torch.Tensor, - suffix_kv_lens: torch.Tensor, - max_kv_len: int, - softmax_scale: float, - alibi_slopes: Optional[torch.Tensor], - sliding_window: Tuple[int, int], - logits_soft_cap: float, - block_table: torch.Tensor, - common_prefix_len: int, - fa_version: int, -) -> torch.Tensor: - assert alibi_slopes is None, ("Cascade attention does not support ALiBi.") - # TODO: Support sliding window. - assert sliding_window == (-1, -1), ( - "Cascade attention does not support sliding window.") - - num_tokens = query.shape[0] - block_size = key_cache.shape[-3] - assert common_prefix_len % block_size == 0 - num_common_kv_blocks = common_prefix_len // block_size - assert num_common_kv_blocks > 0 - - # Process shared prefix. - prefix_output, prefix_lse = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_prefix_query_lens, - seqused_k=prefix_kv_lens, - max_seqlen_q=num_tokens, - max_seqlen_k=common_prefix_len, - softmax_scale=softmax_scale, - causal=False, - window_size=sliding_window, - block_table=block_table[:1], - softcap=logits_soft_cap, - return_softmax_lse=True, - fa_version=fa_version, - ) - - # Process suffix per query. - suffix_output, suffix_lse = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_query_lens, - seqused_k=suffix_kv_lens, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len - common_prefix_len, - softmax_scale=softmax_scale, - causal=True, - window_size=sliding_window, - block_table=block_table[:, num_common_kv_blocks:], - softcap=logits_soft_cap, - return_softmax_lse=True, - fa_version=fa_version, - ) - - # Merge prefix and suffix outputs, and store the result in output. 
- merge_attn_states(output, prefix_output, prefix_lse, suffix_output, - suffix_lse) \ No newline at end of file From 33c980efee9363ab36bcedb50a07153c04ea9547 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 28 Feb 2025 17:28:14 +0000 Subject: [PATCH 06/24] small --- vllm/v1/attention/backends/flashinfer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 5a2166d0461d..984cf823fa0c 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -282,7 +282,7 @@ class FlashInferMetadata: # For logging. num_input_tokens: int = 0 # Number of tokens including padding. - + # The data type of the paged kv cache data_type: torch.dtype = None # The data type of the query @@ -406,8 +406,6 @@ def forward( window_left = (self.sliding_window[0] if self.sliding_window is not None else -1) - - if not attn_metadata.use_cascade: # Regular attention (common case). assert attn_metadata.prefill_wrapper is not None From be8e19385b58c73359ffb9800bef253a98af8187 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 28 Feb 2025 21:41:16 +0000 Subject: [PATCH 07/24] cleanup --- vllm/v1/attention/backends/flashinfer.py | 191 +++++++++++++++-------- vllm/v1/worker/gpu_model_runner.py | 24 +-- 2 files changed, 138 insertions(+), 77 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 984cf823fa0c..c6f9714b5efe 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -24,6 +24,10 @@ from vllm.logger import init_logger from vllm.utils import cdiv +if TYPE_CHECKING: + from vllm.v1.core.scheduler_output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) @@ -64,60 +68,19 @@ def _get_cascade_wrapper(self): return self._cascade_wrapper def begin_forward(self, attn_metadata: "FlashInferMetadata"): - seq_lens = attn_metadata.seq_lens - page_size = attn_metadata.page_size - block_table = attn_metadata.block_table - - if attn_metadata.use_cascade: - # Grab the blocks of the shared prefix from the first request. - num_common_kv_blocks = attn_metadata.common_prefix_len // page_size - shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks], - dtype=torch.int32, - device=block_table.device) - shared_kv_page_indices = block_table[0, :num_common_kv_blocks] - shared_kv_last_page_len = torch.tensor([0], dtype=torch.int32, - device=block_table.device) - # Remove the blocks of the shared prefix from all requests. 
- block_table = block_table[:, num_common_kv_blocks:] - - block_table_bounds = (seq_lens + page_size - 1) // page_size - - # An example for paged_kv_indices, paged_kv_indptr: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - mask = (torch.arange(block_table.size(1), dtype=block_table.dtype, - device=block_table.device).unsqueeze(0) - < block_table_bounds.unsqueeze(1)) - paged_kv_indices = block_table[mask] - - # paged_kv_indptr is used to index into paged_kv_indices: [0, 3, 6, 8] - # Shape: [batch_size + 1] - paged_kv_indptr = torch.cat([ - torch.zeros(1, dtype=block_table_bounds.dtype, - device=block_table_bounds.device), - block_table_bounds.cumsum(dim=0, dtype=torch.int32)]) - - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] - paged_kv_last_page_len = seq_lens % page_size - paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, - page_size, paged_kv_last_page_len) - if self.global_hyperparameters is None: self.global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(self.vllm_config)) - if attn_metadata.use_cascade: attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper.plan( - [attn_metadata.cu_prefix_query_lens, attn_metadata.qo_indptr], - [shared_kv_page_indptr, paged_kv_indptr], - [shared_kv_page_indices, paged_kv_indices], - [shared_kv_last_page_len, paged_kv_last_page_len], + [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr], + [attn_metadata.shared_kv_page_indptr, + attn_metadata.paged_kv_indptr], + [attn_metadata.shared_kv_page_indices, + attn_metadata.paged_kv_indices], + [attn_metadata.shared_kv_last_page_len, + attn_metadata.paged_kv_last_page_len], attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim, @@ -127,15 +90,14 @@ def begin_forward(self, attn_metadata: "FlashInferMetadata"): window_left=self.global_hyperparameters.window_left, logits_soft_cap=self.global_hyperparameters.logits_soft_cap, q_data_type=attn_metadata.q_data_type, - #kv_data_type=attn_metadata.data_type, ) else: attn_metadata.prefill_wrapper = self._get_prefill_wrapper() attn_metadata.prefill_wrapper.plan( attn_metadata.qo_indptr, - paged_kv_indptr, - paged_kv_indices, - paged_kv_last_page_len, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_len, attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim, @@ -164,6 +126,10 @@ def get_impl_cls() -> Type["FlashInferImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return FlashInferMetadata + @staticmethod + def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: + return FlashInferMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -253,14 +219,25 @@ class FlashInferMetadata: num_actual_tokens: int # Number of tokens excluding padding. - seq_lens: torch.Tensor - block_table: torch.Tensor - slot_mapping: torch.Tensor - # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. 
qo_indptr: torch.Tensor + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: torch.Tensor + # The page indices of the paged kv cache + paged_kv_indices: torch.Tensor + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: torch.Tensor # The number of query/output heads num_qo_heads: int # The number of key/value heads @@ -269,13 +246,19 @@ class FlashInferMetadata: head_dim: int # Block size of vllm page_size: int + # The data type of the paged kv cache + data_type: torch.dtype + # The data type of the query + q_data_type: torch.dtype + + slot_mapping: torch.Tensor # For cascade attention. use_cascade: bool - common_prefix_len: int - cu_prefix_query_lens: Optional[torch.Tensor] - prefix_kv_lens: Optional[torch.Tensor] - suffix_kv_lens: Optional[torch.Tensor] + shared_qo_indptr: Optional[torch.Tensor] = None + shared_kv_page_indptr: Optional[torch.Tensor] = None + shared_kv_page_indices: Optional[torch.Tensor] = None + shared_kv_last_page_len: Optional[torch.Tensor] = None prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper = None cascade_wrapper: MultiLevelCascadeAttentionWrapper = None @@ -283,13 +266,6 @@ class FlashInferMetadata: # For logging. num_input_tokens: int = 0 # Number of tokens including padding. - # The data type of the paged kv cache - data_type: torch.dtype = None - # The data type of the query - q_data_type: torch.dtype = None - # FlashInfer 0.2 encourages passing host tensors - device: torch.device = torch.device("cpu") - def __post_init__(self): # Refer to # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 @@ -301,6 +277,87 @@ def __post_init__(self): f" received {self.head_dim}.") +class FlashInferMetadataBuilder: + + def __init__(self, runner: "GPUModelRunner"): + self.runner = runner + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput"): + pass + + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + common_prefix_len: int): + page_size = self.runner.block_size + device = self.runner.device + qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to( + self.runner.device, non_blocking=True) + seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, + non_blocking=True) + block_table = ( + self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + self.runner.device, non_blocking=True).long() + + use_cascade = common_prefix_len > 0 + if use_cascade: + # Grab the blocks of the shared prefix from the first request. + num_common_kv_blocks = common_prefix_len // page_size + shared_qo_indptr = torch.tensor([0, num_actual_tokens], + dtype=torch.int32, + device=device) + shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks], + dtype=torch.int32, + device=device) + shared_kv_page_indices = block_table[0, :num_common_kv_blocks] + shared_kv_last_page_len = torch.tensor([0], dtype=torch.int32, + device=device) + # Remove the blocks of the shared prefix from all requests. 
+ block_table = block_table[:, num_common_kv_blocks:] + else: + shared_qo_indptr = None + shared_kv_page_indptr = None + shared_kv_page_indices = None + shared_kv_last_page_len = None + + block_table_bounds = (seq_lens + page_size - 1) // page_size + + mask = (torch.arange(block_table.size(1), dtype=block_table.dtype, + device=block_table.device).unsqueeze(0) + < block_table_bounds.unsqueeze(1)) + paged_kv_indices = block_table[mask] + + paged_kv_indptr = torch.cat([ + torch.zeros(1, dtype=block_table_bounds.dtype, + device=block_table_bounds.device), + block_table_bounds.cumsum(dim=0, dtype=torch.int32)]) + + paged_kv_last_page_len = seq_lens % page_size + paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, + page_size, paged_kv_last_page_len) + + attn_metadata = FlashInferMetadata( + num_actual_tokens=num_actual_tokens, + qo_indptr=qo_indptr, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=self.runner.num_query_heads, + num_kv_heads=self.runner.num_kv_heads, + head_dim=self.runner.head_size, + page_size=page_size, + data_type=self.runner.kv_cache_dtype, + q_data_type=self.runner.dtype, + slot_mapping=slot_mapping, + use_cascade=use_cascade, + shared_qo_indptr=shared_qo_indptr, + shared_kv_page_indptr=shared_kv_page_indptr, + shared_kv_page_indices=shared_kv_page_indices, + shared_kv_last_page_len=shared_kv_last_page_len, + ) + return attn_metadata + + class FlashInferImpl(AttentionImpl): def __init__( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1992c5edf61f..f9006c17752b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -26,6 +26,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, LayerBlockType, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.flashinfer import FlashInferMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.engine.mm_input_cache import MMInputCacheClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -254,15 +255,9 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - attn_backend_cls = get_attn_backend( - self.model_config.get_head_size(), - self.model_config.dtype, - self.kv_cache_dtype, - self.block_size, - self.model_config.is_attention_free, - use_mla=self.model_config.use_mla, - ) - self.attn_backend = attn_backend_cls(self) + # Instantiate the backend class. + # FIXME: clean up after deciding if the backend should be instantiable. + self.attn_backend = self.attn_backend(self) def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler @@ -578,6 +573,9 @@ def _prepare_inputs( max_query_len=max_num_scheduled_tokens, common_prefix_len=common_prefix_len, ) + if isinstance(attn_metadata, FlashInferMetadata): + # FIXME: abstract this away so it's not flashinfer-specific + self.attn_backend.begin_forward(attn_metadata) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -590,7 +588,13 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. 
- logits_indices = attn_metadata.query_start_loc[1:] - 1 + if isinstance(attn_metadata, FlashInferMetadata): + # FIXME: abstract this away so it's not flashinfer-specific. + # This code should not rely on knowing the internals of the + # attention metadata. + logits_indices = attn_metadata.qo_indptr[1:] - 1 + else: + logits_indices = attn_metadata.query_start_loc[1:] - 1 # Hot-Swap lora model if self.lora_config: From ba095b738a7d2fd191fe745748ffbc95a463fdaf Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 1 Mar 2025 00:19:19 +0000 Subject: [PATCH 08/24] make backend stateless --- vllm/v1/attention/backends/flash_attn.py | 5 +- vllm/v1/attention/backends/flashinfer.py | 176 ++++++++++++----------- vllm/v1/worker/gpu_model_runner.py | 14 +- 3 files changed, 101 insertions(+), 94 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f489eeb5b268..353bf46d503e 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -25,13 +25,10 @@ logger = init_logger(__name__) -class FlashAttentionBackend: +class FlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - def __init__(self, runner): - self.runner = runner - @staticmethod def get_supported_head_sizes() -> List[int]: return [32, 64, 96, 128, 160, 192, 224, 256] diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index c6f9714b5efe..c94d5a1b20db 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -17,7 +17,7 @@ import numpy as np import torch -from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer, +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.layer import Attention from vllm.config import VllmConfig, get_current_vllm_config @@ -32,84 +32,10 @@ logger = init_logger(__name__) -class FlashInferBackend: +class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True - def __init__(self, runner): - self.runner = runner - self._workspace_buffer = None - self._prefill_wrapper = None # Wrapper for prefill/append - self._cascade_wrapper = None # Wrapper for cascade attention - - # Global hyperparameters shared by all attention layers - self.global_hyperparameters: Optional[PerLayerParameters] = None - - self.vllm_config = get_current_vllm_config() - - def _get_workspace_buffer(self): - if self._workspace_buffer is None: - self._workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.runner.device) - return self._workspace_buffer - - def _get_prefill_wrapper(self): - if self._prefill_wrapper is None: - self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), "NHD") - return self._prefill_wrapper - - def _get_cascade_wrapper(self): - if self._cascade_wrapper is None: - self._cascade_wrapper = MultiLevelCascadeAttentionWrapper( - 2, self._get_workspace_buffer(), "NHD") - return self._cascade_wrapper - - def begin_forward(self, attn_metadata: "FlashInferMetadata"): - if self.global_hyperparameters is None: - self.global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(self.vllm_config)) - if attn_metadata.use_cascade: - attn_metadata.cascade_wrapper = self._get_cascade_wrapper() - attn_metadata.cascade_wrapper.plan( - [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr], - 
[attn_metadata.shared_kv_page_indptr, - attn_metadata.paged_kv_indptr], - [attn_metadata.shared_kv_page_indices, - attn_metadata.paged_kv_indices], - [attn_metadata.shared_kv_last_page_len, - attn_metadata.paged_kv_last_page_len], - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, - causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters.logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - ) - else: - attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - attn_metadata.prefill_wrapper.plan( - attn_metadata.qo_indptr, - attn_metadata.paged_kv_indptr, - attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_len, - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, - causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters.logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.data_type, - ) - @staticmethod def get_supported_head_sizes() -> List[int]: return [64, 128, 256] @@ -126,6 +52,10 @@ def get_impl_cls() -> Type["FlashInferImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return FlashInferMetadata + @staticmethod + def get_state_cls() -> Type["FlashInferState"]: + return FlashInferState + @staticmethod def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: return FlashInferMetadataBuilder @@ -139,12 +69,12 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: return (num_blocks, 2, block_size, num_kv_heads, head_size) - #@staticmethod - def use_cascade_attention(self, *args, **kwargs) -> bool: - if self.runner.kv_cache_dtype != self.runner.model_config.dtype: - # TODO: The cascade wrapper currently does not support setting - # kv cache dtype to something different from query dtype. - return False + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + #if self.runner.kv_cache_dtype != self.runner.model_config.dtype: + # # TODO: The cascade wrapper currently does not support setting + # # kv cache dtype to something different from query dtype. 
+ # return False return use_cascade_attention(*args, **kwargs) @@ -214,6 +144,83 @@ def infer_global_hyperparameters( return global_params +class FlashInferState: + + def __init__(self, runner): + self.runner = runner + self._workspace_buffer = None + self._prefill_wrapper = None # Wrapper for prefill/append + self._cascade_wrapper = None # Wrapper for cascade attention + + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = get_current_vllm_config() + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + return self._workspace_buffer + + def _get_prefill_wrapper(self): + if self._prefill_wrapper is None: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), "NHD") + return self._prefill_wrapper + + def _get_cascade_wrapper(self): + if self._cascade_wrapper is None: + self._cascade_wrapper = MultiLevelCascadeAttentionWrapper( + 2, self._get_workspace_buffer(), "NHD") + return self._cascade_wrapper + + def begin_forward(self, attn_metadata: "FlashInferMetadata"): + if self.global_hyperparameters is None: + self.global_hyperparameters = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + if attn_metadata.use_cascade: + attn_metadata.cascade_wrapper = self._get_cascade_wrapper() + attn_metadata.cascade_wrapper.plan( + [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr], + [attn_metadata.shared_kv_page_indptr, + attn_metadata.paged_kv_indptr], + [attn_metadata.shared_kv_page_indices, + attn_metadata.paged_kv_indices], + [attn_metadata.shared_kv_last_page_len, + attn_metadata.paged_kv_last_page_len], + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + causal=True, + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters.logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + ) + else: + attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + attn_metadata.prefill_wrapper.plan( + attn_metadata.qo_indptr, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_len, + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + causal=True, + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters.logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.data_type, + ) + + @dataclass class FlashInferMetadata: @@ -281,6 +288,8 @@ class FlashInferMetadataBuilder: def __init__(self, runner: "GPUModelRunner"): self.runner = runner + self.state = runner.attn_state + assert isinstance(self.state, FlashInferState) def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"): @@ -355,6 +364,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, shared_kv_page_indices=shared_kv_page_indices, shared_kv_last_page_len=shared_kv_last_page_len, ) + self.state.begin_forward(attn_metadata) return attn_metadata @@ -410,7 +420,7 @@ def __init__( def forward( self, - layer: AttentionLayer, + layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, diff 
--git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f9006c17752b..a8130681e8a0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -111,6 +111,13 @@ def __init__( raise NotImplementedError( "Non-Attention backend is not supported by V1 GPUModelRunner.") + try: + attn_state_cls = self.attn_backend.get_state_cls() + except NotImplementedError: + self.attn_state = None + else: + self.attn_state = attn_state_cls(weakref.proxy(self)) + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( weakref.proxy(self)) @@ -255,10 +262,6 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - # Instantiate the backend class. - # FIXME: clean up after deciding if the backend should be instantiable. - self.attn_backend = self.attn_backend(self) - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. @@ -573,9 +576,6 @@ def _prepare_inputs( max_query_len=max_num_scheduled_tokens, common_prefix_len=common_prefix_len, ) - if isinstance(attn_metadata, FlashInferMetadata): - # FIXME: abstract this away so it's not flashinfer-specific - self.attn_backend.begin_forward(attn_metadata) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 From a59fdba9452fcdd79cde8da2b3fb58f3481c3f53 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 1 Mar 2025 00:38:31 +0000 Subject: [PATCH 09/24] clean --- vllm/platforms/cuda.py | 2 +- vllm/platforms/interface.py | 1 - vllm/v1/attention/backends/flashinfer.py | 73 +----------------------- 3 files changed, 3 insertions(+), 73 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 2e26fb9afbb3..425a188699df 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -194,7 +194,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using Triton MLA backend.") return "vllm.attention.backends.triton_mla.TritonMLABackend" if use_v1: - if selected_backend == _Backend.FLASHINFER_VLLM_V1: + if selected_backend == _Backend.FLASHINFER: logger.info("Using FlashInfer backend on V1 engine.") return "vllm.v1.attention.backends.flashinfer.FlashInferBackend" logger.info("Using Flash Attention backend on V1 engine.") diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d5e58e02bff6..d81a66e4bcb1 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -34,7 +34,6 @@ class _Backend(enum.Enum): TORCH_SDPA = enum.auto() OPENVINO = enum.auto() FLASHINFER = enum.auto() - FLASHINFER_VLLM_V1 = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 FLASHMLA = enum.auto() # Supported by V1 HPU_ATTN = enum.auto() diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index c94d5a1b20db..d5595450112f 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -"""Attention layer with FlashAttention.""" +"""Attention layer with FlashInfer.""" from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type @@ -14,7 +14,6 @@ MultiLevelCascadeAttentionWrapper = None FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 -import numpy as np import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -22,7 +21,7 @@ from vllm.attention.layer import Attention from vllm.config 
import VllmConfig, get_current_vllm_config from vllm.logger import init_logger -from vllm.utils import cdiv +from vllm.v1.attention.backends.flash_attn import use_cascade_attention if TYPE_CHECKING: from vllm.v1.core.scheduler_output import SchedulerOutput @@ -494,71 +493,3 @@ def forward( assert attn_metadata.cascade_wrapper is not None output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) return output - - -def use_cascade_attention( - common_prefix_len: int, - query_lens: np.ndarray, - num_query_heads: int, - num_kv_heads: int, - use_alibi: bool, - use_sliding_window: bool, - num_sms: int, -) -> bool: - """Decide whether to use cascade attention. - - This function 1) checks whether cascade attention is supported with the - given configuration, and 2) heuristically decides whether using cascade - attention can improve performance. - """ - # Too short common prefix. Probably not worth using cascade attention. - # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold. - # NOTE(woosuk): This is the common case. We should return False as soon as - # possible to avoid any unnecessary computation. - if common_prefix_len < 256: - return False - # Cascade attention is currently not supported with these variants. - if use_alibi or use_sliding_window: - return False - # Too few queries. Probably not worth using cascade attention. - # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold. - num_reqs = len(query_lens) - if num_reqs < 8: - return False - - # Heuristics to decide whether using cascade attention is beneficial. - # 1. When FlashDecoding is not used for normal attention, cascade attention - # is likely to be faster since it saves memory bandwidth. - num_queries_per_kv = num_query_heads // num_kv_heads - # The criteria for using FlashDecoding can be found in the following link: - # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 - use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window - and not use_alibi and np.all(query_lens == 1)) - if not use_flash_decoding: - # Use cascade attention. - return True - - # 2. When FlashDecoding is used for normal attention, it is not clear - # whether cascade attention is beneficial, because FlashDecoding can - # launch more CTAs than cascade attention. - # We use a simple performance model to compare the two methods. - # NOTE(woosuk): The performance model is very rough and may not be - # accurate. - num_tokens = num_reqs - # NOTE(woosuk): These are default tile sizes. flash-attn might use - # different tile sizes (e.g., 64 or 256) depending on the configuration. - q_tile_size = 128 - kv_tile_size = 128 - num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size) - - cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size) - cascade_waves = cdiv(cascade_ctas, num_sms) - cascade_time = cascade_waves * num_prefix_tiles - - flash_decoding_ctas = (num_reqs * num_kv_heads * - cdiv(num_queries_per_kv, q_tile_size)) - flash_decoding_ctas *= num_prefix_tiles - flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) - - # Use cascade attention if it is faster than FlashDecoding. 
- return cascade_time < flash_decoding_time From 5e4c87c3b8e599e95cb888ef5eddc695a946814e Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 1 Mar 2025 00:52:51 +0000 Subject: [PATCH 10/24] clean --- vllm/v1/attention/backends/flashinfer.py | 37 ++++++++++-------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index d5595450112f..cb8df7918842 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -1,20 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 """Attention layer with FlashInfer.""" +from __future__ import annotations + from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type - -try: - from flashinfer import (BatchPrefillWithPagedKVCacheWrapper, - MultiLevelCascadeAttentionWrapper) - FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 -except ImportError: - # Avoid turning these types into variables during type checking - if not TYPE_CHECKING: - BatchPrefillWithPagedKVCacheWrapper = None - MultiLevelCascadeAttentionWrapper = None - FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import torch +from flashinfer import (BatchPrefillWithPagedKVCacheWrapper, + MultiLevelCascadeAttentionWrapper) + +FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) @@ -41,22 +36,22 @@ def get_supported_head_sizes() -> List[int]: @staticmethod def get_name() -> str: - return "FLASHINFER_VLLM_V1" + return "FLASHINFER" @staticmethod - def get_impl_cls() -> Type["FlashInferImpl"]: + def get_impl_cls() -> FlashInferImpl: return FlashInferImpl @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: + def get_metadata_cls() -> AttentionMetadata: return FlashInferMetadata @staticmethod - def get_state_cls() -> Type["FlashInferState"]: + def get_state_cls() -> FlashInferState: return FlashInferState @staticmethod - def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: + def get_builder_cls() -> FlashInferMetadataBuilder: return FlashInferMetadataBuilder @staticmethod @@ -176,7 +171,7 @@ def _get_cascade_wrapper(self): 2, self._get_workspace_buffer(), "NHD") return self._cascade_wrapper - def begin_forward(self, attn_metadata: "FlashInferMetadata"): + def begin_forward(self, attn_metadata: FlashInferMetadata): if self.global_hyperparameters is None: self.global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(self.vllm_config)) @@ -285,13 +280,13 @@ def __post_init__(self): class FlashInferMetadataBuilder: - def __init__(self, runner: "GPUModelRunner"): + def __init__(self, runner: GPUModelRunner): self.runner = runner self.state = runner.attn_state assert isinstance(self.state, FlashInferState) - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput"): + def reorder_batch(self, input_batch: InputBatch, + scheduler_output: SchedulerOutput): pass def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, From 53f8b81153d52119b93646914ac9de692572de4e Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 1 Mar 2025 01:10:16 +0000 Subject: [PATCH 11/24] remove state class --- vllm/v1/attention/backends/flashinfer.py | 154 +++++++++++------------ vllm/v1/worker/gpu_model_runner.py | 7 -- 2 files changed, 71 insertions(+), 90 deletions(-) diff --git 
a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index cb8df7918842..b0a45c96f1c6 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -46,10 +46,6 @@ def get_impl_cls() -> FlashInferImpl: def get_metadata_cls() -> AttentionMetadata: return FlashInferMetadata - @staticmethod - def get_state_cls() -> FlashInferState: - return FlashInferState - @staticmethod def get_builder_cls() -> FlashInferMetadataBuilder: return FlashInferMetadataBuilder @@ -138,9 +134,72 @@ def infer_global_hyperparameters( return global_params -class FlashInferState: +@dataclass +class FlashInferMetadata: + + num_actual_tokens: int # Number of tokens excluding padding. + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + qo_indptr: torch.Tensor + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: torch.Tensor + # The page indices of the paged kv cache + paged_kv_indices: torch.Tensor + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: torch.Tensor + # The number of query/output heads + num_qo_heads: int + # The number of key/value heads + num_kv_heads: int + # The dimension of the attention heads + head_dim: int + # Block size of vllm + page_size: int + # The data type of the paged kv cache + data_type: torch.dtype + # The data type of the query + q_data_type: torch.dtype + + slot_mapping: torch.Tensor + + # For cascade attention. + use_cascade: bool + shared_qo_indptr: Optional[torch.Tensor] = None + shared_kv_page_indptr: Optional[torch.Tensor] = None + shared_kv_page_indices: Optional[torch.Tensor] = None + shared_kv_last_page_len: Optional[torch.Tensor] = None + + prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper = None + cascade_wrapper: MultiLevelCascadeAttentionWrapper = None + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. 
+ + def __post_init__(self): + # Refer to + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f" received {self.head_dim}.") + + +class FlashInferMetadataBuilder: - def __init__(self, runner): + def __init__(self, runner: GPUModelRunner): self.runner = runner self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append @@ -151,6 +210,10 @@ def __init__(self, runner): self.vllm_config = get_current_vllm_config() + def reorder_batch(self, input_batch: InputBatch, + scheduler_output: SchedulerOutput): + pass + def _get_workspace_buffer(self): if self._workspace_buffer is None: self._workspace_buffer = torch.empty( @@ -171,7 +234,7 @@ def _get_cascade_wrapper(self): 2, self._get_workspace_buffer(), "NHD") return self._cascade_wrapper - def begin_forward(self, attn_metadata: FlashInferMetadata): + def _plan(self, attn_metadata: FlashInferMetadata): if self.global_hyperparameters is None: self.global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(self.vllm_config)) @@ -214,81 +277,6 @@ def begin_forward(self, attn_metadata: FlashInferMetadata): kv_data_type=attn_metadata.data_type, ) - -@dataclass -class FlashInferMetadata: - - num_actual_tokens: int # Number of tokens excluding padding. - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - qo_indptr: torch.Tensor - # An example for paged_kv_indices, paged_kv_indptr: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: torch.Tensor - # The page indices of the paged kv cache - paged_kv_indices: torch.Tensor - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] - paged_kv_last_page_len: torch.Tensor - # The number of query/output heads - num_qo_heads: int - # The number of key/value heads - num_kv_heads: int - # The dimension of the attention heads - head_dim: int - # Block size of vllm - page_size: int - # The data type of the paged kv cache - data_type: torch.dtype - # The data type of the query - q_data_type: torch.dtype - - slot_mapping: torch.Tensor - - # For cascade attention. - use_cascade: bool - shared_qo_indptr: Optional[torch.Tensor] = None - shared_kv_page_indptr: Optional[torch.Tensor] = None - shared_kv_page_indices: Optional[torch.Tensor] = None - shared_kv_last_page_len: Optional[torch.Tensor] = None - - prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper = None - cascade_wrapper: MultiLevelCascadeAttentionWrapper = None - - # For logging. - num_input_tokens: int = 0 # Number of tokens including padding. 
- - def __post_init__(self): - # Refer to - # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 - supported_head_sizes = FlashInferBackend.get_supported_head_sizes() - if self.head_dim is not None and self.head_dim \ - not in supported_head_sizes: - raise ValueError( - f"Only {supported_head_sizes} are supported for head_dim,", - f" received {self.head_dim}.") - - -class FlashInferMetadataBuilder: - - def __init__(self, runner: GPUModelRunner): - self.runner = runner - self.state = runner.attn_state - assert isinstance(self.state, FlashInferState) - - def reorder_batch(self, input_batch: InputBatch, - scheduler_output: SchedulerOutput): - pass - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int): page_size = self.runner.block_size @@ -358,7 +346,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, shared_kv_page_indices=shared_kv_page_indices, shared_kv_last_page_len=shared_kv_last_page_len, ) - self.state.begin_forward(attn_metadata) + self._plan(attn_metadata) return attn_metadata diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a8130681e8a0..67380898efbb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -111,13 +111,6 @@ def __init__( raise NotImplementedError( "Non-Attention backend is not supported by V1 GPUModelRunner.") - try: - attn_state_cls = self.attn_backend.get_state_cls() - except NotImplementedError: - self.attn_state = None - else: - self.attn_state = attn_state_cls(weakref.proxy(self)) - self.attn_metadata_builder = self.attn_backend.get_builder_cls()( weakref.proxy(self)) From 8be9a473f6c21a9dcf1a84aedc73f419bac11abf Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 1 Mar 2025 01:15:58 +0000 Subject: [PATCH 12/24] clean --- vllm/v1/attention/backends/flashinfer.py | 5 +++++ vllm/v1/worker/gpu_model_runner.py | 9 +-------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index b0a45c96f1c6..4a2d4e039f6e 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -186,6 +186,11 @@ class FlashInferMetadata: # For logging. num_input_tokens: int = 0 # Number of tokens including padding. + @property + def query_start_loc(self): + # The GPUModelRunner expects to be able to access this property. + return self.qo_indptr + def __post_init__(self): # Refer to # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 67380898efbb..2730e6770dc3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -26,7 +26,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, LayerBlockType, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.flashinfer import FlashInferMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.engine.mm_input_cache import MMInputCacheClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -581,13 +580,7 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. 
# We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. - if isinstance(attn_metadata, FlashInferMetadata): - # FIXME: abstract this away so it's not flashinfer-specific. - # This code should not rely on knowing the internals of the - # attention metadata. - logits_indices = attn_metadata.qo_indptr[1:] - 1 - else: - logits_indices = attn_metadata.query_start_loc[1:] - 1 + logits_indices = attn_metadata.query_start_loc[1:] - 1 # Hot-Swap lora model if self.lora_config: From 3779aafbfe18410603e8c4a7eb7284e1a44a6f12 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 1 Mar 2025 01:26:19 +0000 Subject: [PATCH 13/24] lint --- vllm/v1/attention/backends/flashinfer.py | 46 ++++++++++++++---------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4a2d4e039f6e..e749b2b12dc7 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -3,14 +3,12 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch from flashinfer import (BatchPrefillWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) -FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 - from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.layer import Attention @@ -23,6 +21,8 @@ from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner +FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 + logger = init_logger(__name__) @@ -39,15 +39,15 @@ def get_name() -> str: return "FLASHINFER" @staticmethod - def get_impl_cls() -> FlashInferImpl: + def get_impl_cls() -> Type[FlashInferImpl]: return FlashInferImpl @staticmethod - def get_metadata_cls() -> AttentionMetadata: + def get_metadata_cls() -> Type[FlashInferMetadata]: return FlashInferMetadata @staticmethod - def get_builder_cls() -> FlashInferMetadataBuilder: + def get_builder_cls() -> Type[FlashInferMetadataBuilder]: return FlashInferMetadataBuilder @staticmethod @@ -242,17 +242,23 @@ def _get_cascade_wrapper(self): def _plan(self, attn_metadata: FlashInferMetadata): if self.global_hyperparameters is None: self.global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(self.vllm_config)) + get_per_layer_parameters(self.vllm_config)) if attn_metadata.use_cascade: attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper.plan( [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr], - [attn_metadata.shared_kv_page_indptr, - attn_metadata.paged_kv_indptr], - [attn_metadata.shared_kv_page_indices, - attn_metadata.paged_kv_indices], - [attn_metadata.shared_kv_last_page_len, - attn_metadata.paged_kv_last_page_len], + [ + attn_metadata.shared_kv_page_indptr, + attn_metadata.paged_kv_indptr + ], + [ + attn_metadata.shared_kv_page_indices, + attn_metadata.paged_kv_indices + ], + [ + attn_metadata.shared_kv_last_page_len, + attn_metadata.paged_kv_last_page_len + ], attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim, @@ -306,7 +312,8 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, dtype=torch.int32, device=device) shared_kv_page_indices = block_table[0, :num_common_kv_blocks] - 
shared_kv_last_page_len = torch.tensor([0], dtype=torch.int32, + shared_kv_last_page_len = torch.tensor([0], + dtype=torch.int32, device=device) # Remove the blocks of the shared prefix from all requests. block_table = block_table[:, num_common_kv_blocks:] @@ -318,15 +325,18 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, block_table_bounds = (seq_lens + page_size - 1) // page_size - mask = (torch.arange(block_table.size(1), dtype=block_table.dtype, + mask = (torch.arange(block_table.size(1), + dtype=block_table.dtype, device=block_table.device).unsqueeze(0) < block_table_bounds.unsqueeze(1)) paged_kv_indices = block_table[mask] paged_kv_indptr = torch.cat([ - torch.zeros(1, dtype=block_table_bounds.dtype, + torch.zeros(1, + dtype=block_table_bounds.dtype, device=block_table_bounds.device), - block_table_bounds.cumsum(dim=0, dtype=torch.int32)]) + block_table_bounds.cumsum(dim=0, dtype=torch.int32) + ]) paged_kv_last_page_len = seq_lens % page_size paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, @@ -466,7 +476,7 @@ def forward( assert attn_metadata.prefill_wrapper._causal assert attn_metadata.prefill_wrapper._window_left == window_left assert attn_metadata.prefill_wrapper._logits_soft_cap == ( - self.logits_soft_cap or 0.0) + self.logits_soft_cap or 0.0) assert attn_metadata.prefill_wrapper._sm_scale == self.scale output = attn_metadata.prefill_wrapper.run( query, From ec4cfbd30230bae36baddaf8ec68f29aaa31552c Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 1 Mar 2025 01:37:30 +0000 Subject: [PATCH 14/24] small --- vllm/v1/attention/backends/flashinfer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index e749b2b12dc7..c3f0f73d1d10 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -10,7 +10,7 @@ MultiLevelCascadeAttentionWrapper) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionType) from vllm.attention.layer import Attention from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger From 24d09d4c7f07381b691d138095cc79775b010bbc Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 1 Mar 2025 01:42:02 +0000 Subject: [PATCH 15/24] clean --- vllm/v1/attention/backends/flashinfer.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index c3f0f73d1d10..6498fe9bf4d7 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -380,9 +380,6 @@ def __init__( logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, ) -> None: - if blocksparse_params is not None: - raise ValueError( - "FlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -390,25 +387,14 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes - if sliding_window is None: - self.sliding_window = (-1, -1) - else: - self.sliding_window = (sliding_window - 1, 0) + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype - if logits_soft_cap is None: - # In flash-attn, setting 
logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0 self.logits_soft_cap = logits_soft_cap assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - support_head_sizes = FlashInferBackend.get_supported_head_sizes() - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by FlashInfer. " - f"Supported head sizes are: {support_head_sizes}.") - if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " From 7bc66c37ad3c2ed428aac78ceb2e121b4f5496b5 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 1 Mar 2025 02:05:07 +0000 Subject: [PATCH 16/24] bugfix --- vllm/v1/attention/backends/flashinfer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 6498fe9bf4d7..5078286b334a 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -304,6 +304,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, use_cascade = common_prefix_len > 0 if use_cascade: # Grab the blocks of the shared prefix from the first request. + assert common_prefix_len % page_size == 0 num_common_kv_blocks = common_prefix_len // page_size shared_qo_indptr = torch.tensor([0, num_actual_tokens], dtype=torch.int32, @@ -312,7 +313,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, dtype=torch.int32, device=device) shared_kv_page_indices = block_table[0, :num_common_kv_blocks] - shared_kv_last_page_len = torch.tensor([0], + shared_kv_last_page_len = torch.tensor([page_size], dtype=torch.int32, device=device) # Remove the blocks of the shared prefix from all requests. From b8d775f6e2ec60e7bf807d43cb8ca32b79a2b2f3 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 1 Mar 2025 02:16:49 +0000 Subject: [PATCH 17/24] bugfix --- vllm/v1/attention/backends/flashinfer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 5078286b334a..3e35291d2811 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -301,6 +301,8 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() + block_table_bounds = (seq_lens + page_size - 1) // page_size + use_cascade = common_prefix_len > 0 if use_cascade: # Grab the blocks of the shared prefix from the first request. @@ -318,14 +320,13 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, device=device) # Remove the blocks of the shared prefix from all requests. 
block_table = block_table[:, num_common_kv_blocks:] + block_table_bounds -= num_common_kv_blocks else: shared_qo_indptr = None shared_kv_page_indptr = None shared_kv_page_indices = None shared_kv_last_page_len = None - block_table_bounds = (seq_lens + page_size - 1) // page_size - mask = (torch.arange(block_table.size(1), dtype=block_table.dtype, device=block_table.device).unsqueeze(0) From ce4d533ab7bcc356039ca189fc2c7cd0570476e4 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 1 Mar 2025 02:31:54 +0000 Subject: [PATCH 18/24] cascade attn test for flashinfer --- tests/v1/e2e/test_cascade_attention.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/v1/e2e/test_cascade_attention.py b/tests/v1/e2e/test_cascade_attention.py index a8079dcce5e2..a93c7c76b1f5 100644 --- a/tests/v1/e2e/test_cascade_attention.py +++ b/tests/v1/e2e/test_cascade_attention.py @@ -1,13 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 +import pytest + from vllm import LLM, SamplingParams +from ...utils import fork_new_process_for_each_test + -def test_cascade_attention(example_system_message, monkeypatch): +@fork_new_process_for_each_test +@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN_VLLM_V1", "FLASHINFER"]) +def test_cascade_attention(example_system_message, monkeypatch, attn_backend): prompt = "\n: Implement fibonacci sequence in Python.\n:" with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) llm = LLM(model="Qwen/Qwen2-1.5B-Instruct") sampling_params = SamplingParams(temperature=0.0, max_tokens=100) From eeaeb2f01d7aac87ecda0e3cb31755cdfbd4a60d Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 1 Mar 2025 02:40:04 +0000 Subject: [PATCH 19/24] move use_cascade_attention to attention metadata builder --- vllm/v1/attention/backends/flash_attn.py | 7 +++---- vllm/v1/attention/backends/flashinfer.py | 17 +++++++++-------- vllm/v1/attention/backends/mla/common.py | 7 +++---- vllm/v1/worker/gpu_model_runner.py | 2 +- 4 files changed, 16 insertions(+), 17 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 353bf46d503e..2dba5731287d 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -60,10 +60,6 @@ def get_kv_cache_shape( raise ValueError("Block size must be a multiple of 16.") return (2, num_blocks, block_size, num_kv_heads, head_size) - @staticmethod - def use_cascade_attention(*args, **kwargs) -> bool: - return use_cascade_attention(*args, **kwargs) - @dataclass class FlashAttentionMetadata: @@ -149,6 +145,9 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, ) return attn_metadata + def use_cascade_attention(self, *args, **kwargs) -> bool: + return use_cascade_attention(*args, **kwargs) + class FlashAttentionImpl(AttentionImpl): diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 3e35291d2811..145ad60cd6a3 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -59,14 +59,6 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: return (num_blocks, 2, block_size, num_kv_heads, head_size) - @staticmethod - def use_cascade_attention(*args, **kwargs) -> bool: - #if self.runner.kv_cache_dtype != self.runner.model_config.dtype: - # # TODO: The cascade wrapper currently does not support setting - # # kv cache dtype to something different from query dtype. 
- # return False - return use_cascade_attention(*args, **kwargs) - @dataclass class PerLayerParameters: @@ -363,9 +355,18 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, shared_kv_page_indices=shared_kv_page_indices, shared_kv_last_page_len=shared_kv_last_page_len, ) + self._plan(attn_metadata) + return attn_metadata + def use_cascade_attention(self, *args, **kwargs) -> bool: + if self.runner.kv_cache_dtype != self.runner.model_config.dtype: + # TODO: The cascade wrapper currently does not support setting + # kv cache dtype to something different from query dtype. + return False + return use_cascade_attention(*args, **kwargs) + class FlashInferImpl(AttentionImpl): diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 30bce5cc8b68..a05389eb2aa8 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -270,10 +270,6 @@ def get_kv_cache_shape( def get_supported_head_sizes() -> List[int]: return [576] - @staticmethod - def use_cascade_attention(*args, **kwargs) -> bool: - return False - @dataclass class MLACommonMetadata: @@ -525,6 +521,9 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, context_chunk_max_seq_lens=context_chunk_max_seq_lens, ) + def use_cascade_attention(self, *args, **kwargs) -> bool: + return False + class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2730e6770dc3..227aca56b77a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -661,7 +661,7 @@ def _compute_cascade_attn_prefix_len( # common_prefix_len should be a multiple of the block size. common_prefix_len = (common_prefix_len // self.block_size * self.block_size) - use_cascade = self.attn_backend.use_cascade_attention( + use_cascade = self.attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, num_query_heads=self.num_query_heads, From 4153ac42c209fd88e9324bcf93c47ad86f39ddf9 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 1 Mar 2025 03:42:15 +0000 Subject: [PATCH 20/24] fix lint --- vllm/v1/attention/backends/flashinfer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 145ad60cd6a3..6e33f63f1e50 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -362,9 +362,9 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, def use_cascade_attention(self, *args, **kwargs) -> bool: if self.runner.kv_cache_dtype != self.runner.model_config.dtype: - # TODO: The cascade wrapper currently does not support setting - # kv cache dtype to something different from query dtype. - return False + # TODO: The cascade wrapper currently does not support setting + # kv cache dtype to something different from query dtype. 
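Moving `use_cascade_attention()` from a backend staticmethod onto the metadata builder, as done above, lets the check consult per-runner state such as the KV-cache dtype before opting into cascade attention. A minimal sketch of that pattern, using simplified stand-in classes rather than the real runner, builder, and heuristic:

```python
# Why the check became an instance method: as a builder method it can read
# runner state. _Runner and its attributes are simplified stand-ins.
from dataclasses import dataclass

@dataclass
class _Runner:
    kv_cache_dtype: str   # e.g. "fp8_e4m3"
    model_dtype: str      # e.g. "bfloat16"

class _MetadataBuilder:
    def __init__(self, runner: _Runner):
        self.runner = runner

    def use_cascade_attention(self, **kwargs) -> bool:
        # The cascade wrapper cannot mix a quantized KV cache with a
        # higher-precision query dtype, so fall back to regular attention.
        if self.runner.kv_cache_dtype != self.runner.model_dtype:
            return False
        # Otherwise defer to the shared heuristic (elided in this sketch).
        return True

builder = _MetadataBuilder(_Runner(kv_cache_dtype="fp8_e4m3",
                                   model_dtype="bfloat16"))
assert builder.use_cascade_attention() is False
```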
+ return False return use_cascade_attention(*args, **kwargs) From 95c402c0f6579a64c639199d8694ad232a359610 Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 15 Apr 2025 21:10:24 +0000 Subject: [PATCH 21/24] Separate prefill and decode Signed-off-by: mgoin --- vllm/engine/arg_utils.py | 11 +- vllm/v1/attention/backends/flashinfer.py | 259 ++++++++++++++++++----- 2 files changed, 217 insertions(+), 53 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 32cb2e90af20..78ae6ff67f03 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1531,8 +1531,15 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: # No FlashInfer or XFormers so far. V1_BACKENDS = [ - "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1", - "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA" + "FLASH_ATTN_VLLM_V1", + "FLASH_ATTN", + "PALLAS", + "PALLAS_VLLM_V1", + "TRITON_ATTN_VLLM_V1", + "TRITON_MLA", + "FLASHMLA", + "FLASHINFER", + "FLASHINFER_VLLM_V1", ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 6e33f63f1e50..0bdadf6f497e 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -3,12 +3,14 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Optional import torch -from flashinfer import (BatchPrefillWithPagedKVCacheWrapper, +from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) +import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.attention.layer import Attention @@ -17,7 +19,7 @@ from vllm.v1.attention.backends.flash_attn import use_cascade_attention if TYPE_CHECKING: - from vllm.v1.core.scheduler_output import SchedulerOutput + from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -31,23 +33,23 @@ class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True @staticmethod - def get_supported_head_sizes() -> List[int]: + def get_supported_head_sizes() -> list[int]: return [64, 128, 256] @staticmethod def get_name() -> str: - return "FLASHINFER" + return "FLASHINFER_VLLM_V1" @staticmethod - def get_impl_cls() -> Type[FlashInferImpl]: + def get_impl_cls() -> type[FlashInferImpl]: return FlashInferImpl @staticmethod - def get_metadata_cls() -> Type[FlashInferMetadata]: + def get_metadata_cls() -> type[FlashInferMetadata]: return FlashInferMetadata @staticmethod - def get_builder_cls() -> Type[FlashInferMetadataBuilder]: + def get_builder_cls() -> type[FlashInferMetadataBuilder]: return FlashInferMetadataBuilder @staticmethod @@ -56,7 +58,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: return (num_blocks, 2, block_size, num_kv_heads, head_size) @@ -73,14 +75,14 @@ class PerLayerParameters: def get_per_layer_parameters( - vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]: + vllm_config: VllmConfig) -> dict[str, PerLayerParameters]: """ Scan all attention layers and determine some hyperparameters to use during `plan`. 
""" layers = vllm_config.compilation_config.static_forward_context - per_layer_params: Dict[str, PerLayerParameters] = {} + per_layer_params: dict[str, PerLayerParameters] = {} for key, layer in layers.items(): assert isinstance(layer, Attention) @@ -101,7 +103,7 @@ def get_per_layer_parameters( def infer_global_hyperparameters( - per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters: + per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: """ Currently, FlashInfer backend only support models in which all layers share the same values for the following hyperparameters: @@ -165,6 +167,12 @@ class FlashInferMetadata: slot_mapping: torch.Tensor + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + # For cascade attention. use_cascade: bool shared_qo_indptr: Optional[torch.Tensor] = None @@ -172,8 +180,9 @@ class FlashInferMetadata: shared_kv_page_indices: Optional[torch.Tensor] = None shared_kv_last_page_len: Optional[torch.Tensor] = None - prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper = None - cascade_wrapper: MultiLevelCascadeAttentionWrapper = None + prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None + decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None # For logging. num_input_tokens: int = 0 # Number of tokens including padding. @@ -200,6 +209,7 @@ def __init__(self, runner: GPUModelRunner): self.runner = runner self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append + self._decode_wrapper = None # Wrapper for decode self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers @@ -208,8 +218,65 @@ def __init__(self, runner: GPUModelRunner): self.vllm_config = get_current_vllm_config() def reorder_batch(self, input_batch: InputBatch, - scheduler_output: SchedulerOutput): - pass + scheduler_output: SchedulerOutput) -> bool: + # We now want to reorder the batch so that the "decode" requests are and + # the front and the "prefill" requests are at the using the least amount + # swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + num_decode_tokens = 0 + num_prefill_tokens = 0 + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # for now treat 1 scheduled token as "decode" even if its not, + # we should update this to something like < 8 in the future but + # currently the decode run only supports num_tokens = 1 + if num_tokens == 1: + decodes.append(i) + num_decode_tokens += num_tokens + else: + prefills.append(i) + num_prefill_tokens += num_tokens + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. 
+ # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + first_prefill = 0 + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + if decodes[num_decodes - i] >= num_decodes: + input_batch.swap_states(prefills[first_prefill], + decodes[num_decodes - i]) + first_prefill += 1 + modified_batch = True + else: + break + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + self._num_decodes = num_decodes + self._num_prefills = num_prefills + self._num_decode_tokens = num_decode_tokens + self._num_prefill_tokens = num_prefill_tokens + + return modified_batch def _get_workspace_buffer(self): if self._workspace_buffer is None: @@ -225,6 +292,20 @@ def _get_prefill_wrapper(self): self._get_workspace_buffer(), "NHD") return self._prefill_wrapper + def _get_decode_wrapper(self): + if self._decode_wrapper is None: + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( + num_qo_heads // num_kv_heads > 4) + self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self._get_workspace_buffer(), + "NHD", + use_tensor_cores=use_tensor_cores) + return self._decode_wrapper + def _get_cascade_wrapper(self): if self._cascade_wrapper is None: self._cascade_wrapper = MultiLevelCascadeAttentionWrapper( @@ -262,26 +343,67 @@ def _plan(self, attn_metadata: FlashInferMetadata): q_data_type=attn_metadata.q_data_type, ) else: - attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - attn_metadata.prefill_wrapper.plan( - attn_metadata.qo_indptr, - attn_metadata.paged_kv_indptr, - attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_len, - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, - causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters.logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.data_type, - ) + # Regular attention (common case). + # Decodes are at the front and prefills are at the back, + # according to reorder_batch() + if self._num_prefills > 0: + # Decodes are first so prefills start after the last decode + prefill_start = self._num_decodes + attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + assert attn_metadata.qo_indptr[prefill_start:].shape[ + 0] == self._num_prefills + 1 + assert attn_metadata.paged_kv_indptr[prefill_start:].shape[ + 0] == self._num_prefills + 1 + assert attn_metadata.paged_kv_last_page_len[ + prefill_start:].shape[0] == self._num_prefills + # Since prefill_wrapper.run() will be called with + # query[num_decode_tokens:] we need to adjust the qo_indptr + # to be relative to the start of the prefill queries. 
+ qo_indptr = attn_metadata.qo_indptr[ + prefill_start:] - attn_metadata.qo_indptr[prefill_start] + attn_metadata.prefill_wrapper.plan( + qo_indptr, + attn_metadata.paged_kv_indptr[prefill_start:], + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_len[prefill_start:], + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + causal=True, + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters. + logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.data_type, + ) + + if self._num_decodes > 0: + attn_metadata.decode_wrapper = self._get_decode_wrapper() + attn_metadata.decode_wrapper.plan( + attn_metadata.paged_kv_indptr[:self._num_decodes + 1], + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_len[:self._num_decodes], + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters. + logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.data_type, + ) def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int): + assert self._num_decodes + self._num_prefills == num_reqs + assert (self._num_decode_tokens + + self._num_prefill_tokens == num_actual_tokens) page_size = self.runner.block_size device = self.runner.device qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to( @@ -349,6 +471,10 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, data_type=self.runner.kv_cache_dtype, q_data_type=self.runner.dtype, slot_mapping=slot_mapping, + num_decodes=self._num_decodes, + num_decode_tokens=self._num_decode_tokens, + num_prefills=self._num_prefills, + num_prefill_tokens=self._num_prefill_tokens, use_cascade=use_cascade, shared_qo_indptr=shared_qo_indptr, shared_kv_page_indptr=shared_kv_page_indptr, @@ -376,10 +502,10 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[List[float]], + alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, + blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, ) -> None: @@ -440,6 +566,7 @@ def forward( # Whenever making a change in this method, please benchmark the # performance to make sure it does not introduce any overhead. + num_actual_tokens = attn_metadata.num_actual_tokens # Reshape the input keys and values and store them in the cache. # NOTE(woosuk): Here, key and value are padded while slot_mapping is # not padded. However, we don't need to do key[:num_actual_tokens] and @@ -459,24 +586,54 @@ def forward( window_left = (self.sliding_window[0] if self.sliding_window is not None else -1) - if not attn_metadata.use_cascade: - # Regular attention (common case). 
- assert attn_metadata.prefill_wrapper is not None - assert attn_metadata.prefill_wrapper._causal - assert attn_metadata.prefill_wrapper._window_left == window_left - assert attn_metadata.prefill_wrapper._logits_soft_cap == ( - self.logits_soft_cap or 0.0) - assert attn_metadata.prefill_wrapper._sm_scale == self.scale - output = attn_metadata.prefill_wrapper.run( - query, + # Inputs and outputs may be padded for CUDA graphs + query = query[:num_actual_tokens] + output_padded = output + output = output[:num_actual_tokens] + + if attn_metadata.use_cascade: + # Cascade attention (rare case). + assert attn_metadata.cascade_wrapper is not None + output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) + return output + + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_tokens = attn_metadata.num_prefill_tokens + + # Regular attention (common case). + # Decodes are at the front and prefills are at the back, + # according to reorder_batch() + if prefill_wrapper := attn_metadata.prefill_wrapper: + prefill_query = query[num_decode_tokens:] + assert prefill_query.shape[0] == num_prefill_tokens + assert prefill_wrapper is not None + assert prefill_wrapper._causal + assert prefill_wrapper._window_left == window_left + assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap + or 0.0) + assert prefill_wrapper._sm_scale == self.scale + prefill_wrapper.run( + prefill_query, kv_cache, k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, - out=output, + out=output[num_decode_tokens:], + ) + + if decode_wrapper := attn_metadata.decode_wrapper: + decode_query = query[:num_decode_tokens] + assert decode_query.shape[0] == num_decode_tokens + assert decode_wrapper is not None + assert decode_wrapper._window_left == window_left + assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap + or 0.0) + assert decode_wrapper._sm_scale == self.scale + decode_wrapper.run( + decode_query, + kv_cache, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[:num_decode_tokens], ) - return output - # Cascade attention (rare case). 
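The rewritten `forward()` above strips any CUDA-graph padding once, then routes the front of the batch through the decode wrapper and the rest through the prefill wrapper, each writing into its own slice of the shared output buffer. A shape-only sketch of that slicing, with made-up sizes and the wrapper calls replaced by plain copies:

```python
# Shape-only sketch of the decode/prefill split in forward() above.
# Sizes are invented; the wrapper .run() calls are stand-in copies.
import torch

num_decode_tokens, num_prefill_tokens, padded = 3, 8, 16
num_actual_tokens = num_decode_tokens + num_prefill_tokens

query = torch.randn(padded, 32, 128)       # [num_tokens, num_heads, head_size]
output_padded = torch.zeros_like(query)

query = query[:num_actual_tokens]          # drop CUDA-graph padding
output = output_padded[:num_actual_tokens]

decode_query = query[:num_decode_tokens]   # fed to decode_wrapper.run()
prefill_query = query[num_decode_tokens:]  # fed to prefill_wrapper.run()
assert prefill_query.shape[0] == num_prefill_tokens

# Each wrapper writes into its own slice of the shared output buffer:
output[:num_decode_tokens] = decode_query    # stand-in for out=output[:num_decode_tokens]
output[num_decode_tokens:] = prefill_query   # stand-in for out=output[num_decode_tokens:]

# The padded buffer is what the caller ultimately receives.
assert output_padded.shape[0] == padded
```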
- assert attn_metadata.cascade_wrapper is not None - output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) - return output + return output_padded From 77f5b1faa4ad77001c0b29b69607dd3490f60f27 Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 16 Apr 2025 16:48:08 +0000 Subject: [PATCH 22/24] Update test cascade Signed-off-by: mgoin --- tests/v1/e2e/test_cascade_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/v1/e2e/test_cascade_attention.py b/tests/v1/e2e/test_cascade_attention.py index a93c7c76b1f5..48c265560348 100644 --- a/tests/v1/e2e/test_cascade_attention.py +++ b/tests/v1/e2e/test_cascade_attention.py @@ -8,7 +8,8 @@ @fork_new_process_for_each_test -@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN_VLLM_V1", "FLASHINFER"]) +@pytest.mark.parametrize("attn_backend", + ["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"]) def test_cascade_attention(example_system_message, monkeypatch, attn_backend): prompt = "\n: Implement fibonacci sequence in Python.\n:" From b8ced05c11de256e19b89667dbd7536538908269 Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 16 Apr 2025 19:55:37 +0000 Subject: [PATCH 23/24] Updates Signed-off-by: mgoin --- vllm/engine/arg_utils.py | 2 +- vllm/v1/attention/backends/flashinfer.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 78ae6ff67f03..a9d693577561 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1529,7 +1529,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # No FlashInfer or XFormers so far. + # No XFormers so far. V1_BACKENDS = [ "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 0bdadf6f497e..12dcd0354363 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -516,8 +516,10 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window - 1, - 0) if sliding_window is not None else (-1, -1)) + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap From 44c871a39a37684f53b1d577a44ef49fefccae93 Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 18 Apr 2025 15:11:36 +0000 Subject: [PATCH 24/24] Update reorder logic Signed-off-by: mgoin --- vllm/v1/attention/backends/flashinfer.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 12dcd0354363..17341ecfa4fe 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -254,20 +254,18 @@ def reorder_batch(self, input_batch: InputBatch, # the above loop num_decodes = len(decodes) num_prefills = len(prefills) - first_prefill = 0 modified_batch = False for i in range(1, min(num_decodes, num_prefills) + 1): # If the decode is at the "back" of the batch, i, we can swap it # with the prefill closest to the front of the batch - if decodes[num_decodes - i] >= num_decodes: - input_batch.swap_states(prefills[first_prefill], - decodes[num_decodes - i]) - first_prefill += 1 - modified_batch = True - else: + decode_idx = decodes[num_decodes - i] + if decode_idx < 
num_decodes: break + input_batch.swap_states(prefills[i - 1], decode_idx) + modified_batch = True + # Save for next `build` call # TODO(lucas): this is a bit of a hack, we should probably have a # better way of doing this
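The simplified loop above swaps a decode from the back of the batch with the prefill nearest the front, and stops as soon as the next decode already sits inside the decode region. A standalone sketch of the same loop over plain Python lists, with `swap_states()` replaced by a list swap for illustration:

```python
# Standalone sketch of the updated swap loop above: walk decodes from the
# back and prefills from the front, swapping only while a decode still sits
# in the region that should belong to prefills.
def reorder(decodes: list[int], prefills: list[int], batch: list[str]) -> bool:
    num_decodes, num_prefills = len(decodes), len(prefills)
    modified = False
    for i in range(1, min(num_decodes, num_prefills) + 1):
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            break                      # decode already in the decode region
        batch[prefills[i - 1]], batch[decode_idx] = (
            batch[decode_idx], batch[prefills[i - 1]])
        modified = True
    return modified

# Requests at indices [p, d, d, p, d]; decodes should end up at the front.
batch = ["p0", "d1", "d2", "p3", "d4"]
reorder(decodes=[1, 2, 4], prefills=[0, 3], batch=batch)
print(batch)   # ['d4', 'd1', 'd2', 'p3', 'p0']
```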