From c8481a63cc4ad703b44e97cfb2ccd0e47703774a Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 4 Mar 2025 21:06:33 +0000 Subject: [PATCH 1/8] fix IMA Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 36 ++++++++++++++++---- vllm/v1/attention/backends/mla/flashmla.py | 11 +++--- vllm/v1/attention/backends/mla/triton_mla.py | 8 +++-- 3 files changed, 42 insertions(+), 13 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 824ffcfd61b..340fce42879 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -314,12 +314,19 @@ class MLACommonMetadata: num_decode_tokens: Optional[int] = None num_prefills: Optional[int] = None has_context: bool = False + context_chunk_cu_seq_lens: Optional[torch.Tensor] = None context_chunk_starts: Optional[torch.Tensor] = None context_chunk_seq_tot: Optional[list[int]] = None context_chunk_max_seq_lens: Optional[list[int]] = None chunked_prefill_workspace: Optional[torch.Tensor] = None + # Computed in __post_init__ + prefill_query_start_loc: Optional[torch.Tensor] = None + prefill_max_query_len: Optional[int] = None + decode_seq_lens: Optional[torch.Tensor] = None + decode_block_table: Optional[torch.Tensor] = None + def __post_init__(self): supported_head_sizes = MLACommonBackend.get_supported_head_sizes() if self.head_dim is not None and self.head_dim \ @@ -328,6 +335,18 @@ def __post_init__(self): f"Only {supported_head_sizes} are supported for head_dim,", f"received {self.head_dim}.") + # Pre-compute prefill/decode tensor slices and other stats + if self.num_prefills is not None and self.num_prefills > 0: + assert self.num_decodes is not None and self.num_decodes > 0 + start = self.num_decodes # prefill_start + self.prefill_query_start_loc = \ + self.query_start_loc[start:] - self.query_start_loc[start] + self.prefill_max_query_len = self.seq_lens[start:].max().item() + + if self.num_decodes is not None and self.num_decodes > 0: + self.decode_seq_lens = self.seq_lens[:self.num_decodes] + self.decode_block_table = self.block_table[:self.num_decodes, ...] 
+ T = TypeVar("T", bound=MLACommonMetadata) @@ -803,6 +822,8 @@ def _compute_prefill_context( assert attn_metadata.context_chunk_cu_seq_lens is not None assert attn_metadata.context_chunk_starts is not None assert attn_metadata.context_chunk_max_seq_lens is not None + assert attn_metadata.prefill_query_start_loc is not None + assert attn_metadata.prefill_max_query_len is not None output = None iters = len(attn_metadata.context_chunk_seq_tot) @@ -845,9 +866,9 @@ def _compute_prefill_context( q=q, k=k, v=v_padded, - cu_seqlens_q=attn_metadata.query_start_loc, + cu_seqlens_q=attn_metadata.prefill_query_start_loc, cu_seqlens_k=attn_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=attn_metadata.max_query_len, + max_seqlen_q=attn_metadata.prefill_max_query_len, max_seqlen_k=attn_metadata.context_chunk_max_seq_lens[i], softmax_scale=self.scale, causal=False, # Context is unmasked @@ -881,6 +902,9 @@ def _forward_prefill( kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, ) -> torch.Tensor: + assert attn_metadata.prefill_query_start_loc is not None + assert attn_metadata.prefill_max_query_len is not None + has_context = attn_metadata.has_context kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) @@ -898,10 +922,10 @@ def _forward_prefill( q=q, k=k, v=v_padded, - cu_seqlens_q=attn_metadata.query_start_loc, - cu_seqlens_k=attn_metadata.query_start_loc, - max_seqlen_q=attn_metadata.max_query_len, - max_seqlen_k=attn_metadata.max_seq_len, + cu_seqlens_q=attn_metadata.prefill_query_start_loc, + cu_seqlens_k=attn_metadata.prefill_query_start_loc, + max_seqlen_q=attn_metadata.prefill_max_query_len, + max_seqlen_k=attn_metadata.prefill_max_query_len, softmax_scale=self.scale, causal=True, return_softmax_lse=has_context, diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index b357d714241..59580c6d52e 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -60,7 +60,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, if m.num_decode_tokens is not None and m.num_decode_tokens > 0: m.decode_tile_scheduler_metadata, m.decode_num_splits = \ get_mla_metadata( - m.seq_lens[:m.num_decode_tokens], + m.decode_seq_lens, self.num_q_heads, 1, # MQA for the decode path ) @@ -115,6 +115,9 @@ def _forward_decode( attn_metadata: FlashMLAMetadata, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode_block_table is not None + assert attn_metadata.decode_seq_lens is not None + if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError("FP8 FlashMLA not yet supported") @@ -124,10 +127,8 @@ def _forward_decode( o, _ = flash_mla_with_kvcache( q=q, k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 - block_table=attn_metadata.block_table[:attn_metadata.num_decodes, - ...], - cache_seqlens=attn_metadata.seq_lens[:attn_metadata. - num_decode_tokens], + block_table=attn_metadata.decode_block_table, + cache_seqlens=attn_metadata.decode_seq_lens, head_dim_v=self.kv_lora_rank, tile_scheduler_metadata=attn_metadata. 
decode_tile_scheduler_metadata, diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 3f9b349a5f0..c9d015b58c0 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -69,6 +69,9 @@ def _forward_decode( attn_metadata: MLACommonMetadata, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode_block_table is not None + assert attn_metadata.decode_seq_lens is not None + if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError("FP8 Triton MLA not yet supported") @@ -104,7 +107,8 @@ def _forward_decode( # Run MQA decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, - attn_metadata.block_table, attn_metadata.seq_lens, - attn_logits, num_kv_splits, self.scale, PAGE_SIZE) + attn_metadata.decode_block_table, + attn_metadata.decode_seq_lens, attn_logits, + num_kv_splits, self.scale, PAGE_SIZE) return self._v_up_proj_and_o_proj(o) From e9a8824bb505bd0d24dcb3cb34f51d59e4a84317 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 4 Mar 2025 21:25:11 +0000 Subject: [PATCH 2/8] bugfix Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 340fce42879..52496bf688e 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -337,7 +337,7 @@ def __post_init__(self): # Pre-compute prefill/decode tensor slices and other stats if self.num_prefills is not None and self.num_prefills > 0: - assert self.num_decodes is not None and self.num_decodes > 0 + assert self.num_decodes is not None start = self.num_decodes # prefill_start self.prefill_query_start_loc = \ self.query_start_loc[start:] - self.query_start_loc[start] From b87185c4b68b5c6e1f2da6cf992866751b59c385 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 4 Mar 2025 22:51:57 +0000 Subject: [PATCH 3/8] working Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 45 ++++++++++++++++-------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 52496bf688e..9e3d7cd234f 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -302,6 +302,13 @@ class MLACommonMetadata: block_table: torch.Tensor slot_mapping: torch.Tensor + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + has_context: bool + # For logging. num_input_tokens: int = 0 # Number of tokens including padding. 
@@ -309,21 +316,18 @@ class MLACommonMetadata: head_dim: Optional[int] = None # New for MLA (compared to FlashAttention) - # For chunked prefill - num_decodes: Optional[int] = None - num_decode_tokens: Optional[int] = None - num_prefills: Optional[int] = None - has_context: bool = False - + # For handling chunked prefill context_chunk_cu_seq_lens: Optional[torch.Tensor] = None context_chunk_starts: Optional[torch.Tensor] = None context_chunk_seq_tot: Optional[list[int]] = None context_chunk_max_seq_lens: Optional[list[int]] = None chunked_prefill_workspace: Optional[torch.Tensor] = None - # Computed in __post_init__ + # New for MLA (compared to FlashAttention) + # For handling prefill decode split prefill_query_start_loc: Optional[torch.Tensor] = None prefill_max_query_len: Optional[int] = None + prefill_block_table: Optional[torch.Tensor] = None decode_seq_lens: Optional[torch.Tensor] = None decode_block_table: Optional[torch.Tensor] = None @@ -342,6 +346,7 @@ def __post_init__(self): self.prefill_query_start_loc = \ self.query_start_loc[start:] - self.query_start_loc[start] self.prefill_max_query_len = self.seq_lens[start:].max().item() + self.prefill_block_table = self.block_table[start:, ...] if self.num_decodes is not None and self.num_decodes > 0: self.decode_seq_lens = self.seq_lens[:self.num_decodes] @@ -453,6 +458,8 @@ def reorder_batch(self, input_batch: "InputBatch", def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int) -> T: + assert self._num_decodes + self._num_prefills == num_reqs + device = self.runner.device max_seq_len = self.runner.seq_lens_np[:num_reqs].max() query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( @@ -473,19 +480,23 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, num_computed_tokens_cpu_tensor = \ self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs] - context_lens_tensor = \ - num_computed_tokens_cpu_tensor.to(device, non_blocking=True) + prefill_context_lens_tensor = \ + num_computed_tokens_cpu_tensor[self._num_decodes:]\ + .to(device, non_blocking=True) + + has_context = False if self.chunked_prefill_enabled and self._num_prefills > 0 \ - and context_lens_tensor[self._num_decodes:].max() > 0: + and prefill_context_lens_tensor.max() > 0: + # NOTE: it is recommend you read the `Chunked Prefill` section in # the comment at the top of the file before trying to understand # the following code - self.has_context = True + has_context = True num_prefills_with_context = \ - (context_lens_tensor[self._num_decodes:] > 0).sum().item() + (prefill_context_lens_tensor > 0).sum().item() # currently we allocate an equal amount of workspace for each # prefill in the batch, we could probably use a more advanced @@ -499,7 +510,8 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, # `context_chunk_starts` that are not aligned to page_size max_context_chunk = round_down(max_context_chunk, self.page_size) assert max_context_chunk > 0 - num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) + num_chunks = cdiv(prefill_context_lens_tensor.max(), + max_context_chunk) # if `max_context_chunk = 256`, `num_chunks = 3`, and # `num_prefills_with_context = 4`, create a tensor that looks like @@ -508,7 +520,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, torch.arange(num_chunks, device=device, dtype=torch.int32) \ .unsqueeze(1).expand(-1, self._num_prefills) \ * max_context_chunk - chunk_ends = 
torch.min(context_lens_tensor[self._num_decodes:] \ + chunk_ends = torch.min(prefill_context_lens_tensor \ .unsqueeze(0), context_chunk_starts + max_context_chunk) chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0) _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( @@ -537,10 +549,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, num_decodes=self._num_decodes, num_decode_tokens=self._num_decode_tokens, num_prefills=self._num_prefills, + has_context=has_context, context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, context_chunk_starts=context_chunk_starts, context_chunk_seq_tot=context_chunk_seq_tot, context_chunk_max_seq_lens=context_chunk_max_seq_lens, + chunked_prefill_workspace=self.chunked_prefill_workspace, ) @@ -824,6 +838,7 @@ def _compute_prefill_context( assert attn_metadata.context_chunk_max_seq_lens is not None assert attn_metadata.prefill_query_start_loc is not None assert attn_metadata.prefill_max_query_len is not None + assert attn_metadata.prefill_block_table is not None output = None iters = len(attn_metadata.context_chunk_seq_tot) @@ -837,7 +852,7 @@ def _compute_prefill_context( ops.gather_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, - block_table=attn_metadata.block_table, + block_table=attn_metadata.prefill_block_table, cu_seq_lens=attn_metadata.context_chunk_cu_seq_lens[i], batch_size=attn_metadata.num_prefills, seq_starts=attn_metadata.context_chunk_starts[i], From 76904b223ffb87b1998d7aefdcaaf1207d38b0ce Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 5 Mar 2025 02:14:44 +0000 Subject: [PATCH 4/8] cleanup Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 285 +++++++++---------- vllm/v1/attention/backends/mla/flashmla.py | 10 +- vllm/v1/attention/backends/mla/triton_mla.py | 7 +- 3 files changed, 147 insertions(+), 155 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 9e3d7cd234f..5c49886f245 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -281,11 +281,6 @@ class MLACommonMetadata: NOTE: Please read the comment at the top of the file before trying to understand this class """ - # New for MLA (compared to FlashAttention) - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -295,11 +290,7 @@ class MLACommonMetadata: # |-- query_len ---| num_actual_tokens: int # Number of tokens excluding padding. - max_query_len: int query_start_loc: torch.Tensor - max_seq_len: int - seq_lens: torch.Tensor - block_table: torch.Tensor slot_mapping: torch.Tensor # New for MLA (compared to FlashAttention) @@ -307,7 +298,6 @@ class MLACommonMetadata: num_decodes: int num_decode_tokens: int num_prefills: int - has_context: bool # For logging. num_input_tokens: int = 0 # Number of tokens including padding. 
@@ -315,21 +305,38 @@ class MLACommonMetadata: # The dimension of the attention heads head_dim: Optional[int] = None - # New for MLA (compared to FlashAttention) - # For handling chunked prefill - context_chunk_cu_seq_lens: Optional[torch.Tensor] = None - context_chunk_starts: Optional[torch.Tensor] = None - context_chunk_seq_tot: Optional[list[int]] = None - context_chunk_max_seq_lens: Optional[list[int]] = None - chunked_prefill_workspace: Optional[torch.Tensor] = None - - # New for MLA (compared to FlashAttention) - # For handling prefill decode split - prefill_query_start_loc: Optional[torch.Tensor] = None - prefill_max_query_len: Optional[int] = None - prefill_block_table: Optional[torch.Tensor] = None - decode_seq_lens: Optional[torch.Tensor] = None - decode_block_table: Optional[torch.Tensor] = None + @dataclass + class PrefillMetadata: + """ Prefill Specific Metadata """ + + @dataclass + class ChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + query_start_loc: torch.Tensor + max_query_len: int + chunked_context: Optional[ChunkedContextMetadata] = None + + @dataclass + class DecodeMetadata: + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + + decode: Optional[DecodeMetadata] = None + prefill: Optional[PrefillMetadata] = None def __post_init__(self): supported_head_sizes = MLACommonBackend.get_supported_head_sizes() @@ -339,19 +346,6 @@ def __post_init__(self): f"Only {supported_head_sizes} are supported for head_dim,", f"received {self.head_dim}.") - # Pre-compute prefill/decode tensor slices and other stats - if self.num_prefills is not None and self.num_prefills > 0: - assert self.num_decodes is not None - start = self.num_decodes # prefill_start - self.prefill_query_start_loc = \ - self.query_start_loc[start:] - self.query_start_loc[start] - self.prefill_max_query_len = self.seq_lens[start:].max().item() - self.prefill_block_table = self.block_table[start:, ...] - - if self.num_decodes is not None and self.num_decodes > 0: - self.decode_seq_lens = self.seq_lens[:self.num_decodes] - self.decode_block_table = self.block_table[:self.num_decodes, ...] 
- T = TypeVar("T", bound=MLACommonMetadata) @@ -364,7 +358,7 @@ class MLACommonMetadataBuilder(Generic[T]): def __init__(self, runner: "GPUModelRunner", - cls: Optional[type[T]] = None): + cls: Optional[type[MLACommonMetadata]] = None): self.cls = cls if cls is not None else MLACommonMetadata self.runner = runner scheduler_config = runner.scheduler_config @@ -461,7 +455,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, assert self._num_decodes + self._num_prefills == num_reqs device = self.runner.device - max_seq_len = self.runner.seq_lens_np[:num_reqs].max() query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( device, non_blocking=True) seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device, @@ -473,88 +466,98 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, input_positions = self.runner.positions_cpu[:num_actual_tokens].to( device, non_blocking=True).long() - context_chunk_cu_seq_lens = None - context_chunk_starts = None - context_chunk_seq_tot = None - context_chunk_max_seq_lens = None - - num_computed_tokens_cpu_tensor = \ - self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs] - prefill_context_lens_tensor = \ - num_computed_tokens_cpu_tensor[self._num_decodes:]\ - .to(device, non_blocking=True) - - has_context = False - - if self.chunked_prefill_enabled and self._num_prefills > 0 \ - and prefill_context_lens_tensor.max() > 0: - - # NOTE: it is recommend you read the `Chunked Prefill` section in - # the comment at the top of the file before trying to understand - # the following code - - has_context = True - - num_prefills_with_context = \ - (prefill_context_lens_tensor > 0).sum().item() - - # currently we allocate an equal amount of workspace for each - # prefill in the batch, we could probably use a more advanced - # algorithm here and allocate more workspace to prefills with - # longer context lengths - max_context_chunk = \ - self.chunked_prefill_workspace_size // num_prefills_with_context - - # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size - max_context_chunk = round_down(max_context_chunk, self.page_size) - assert max_context_chunk > 0 - num_chunks = cdiv(prefill_context_lens_tensor.max(), - max_context_chunk) - - # if `max_context_chunk = 256`, `num_chunks = 3`, and - # `num_prefills_with_context = 4`, create a tensor that looks like - # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] - context_chunk_starts = \ - torch.arange(num_chunks, device=device, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, self._num_prefills) \ - * max_context_chunk - chunk_ends = torch.min(prefill_context_lens_tensor \ - .unsqueeze(0), context_chunk_starts + max_context_chunk) - chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0) - _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( - torch.int32) - zero = torch.zeros(num_chunks, dtype=torch.int32, device=device) \ - .unsqueeze(-1) - context_chunk_cu_seq_lens = \ - torch.cat([zero, _context_chunk_cu_seq_lens], dim=1) - context_chunk_max_seq_lens = \ - chunk_seq_lens.max(dim=1).values.tolist() - context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist() - assert max(context_chunk_seq_tot) <= \ - self.chunked_prefill_workspace_size + prefill_metadata = None + if self._num_prefills > 0: + start = self._num_decodes # prefill_start + + context_lens_cpu = self.runner.input_batch.\ + 
num_computed_tokens_cpu_tensor[start:num_reqs] + context_lens = context_lens_cpu.to(device, non_blocking=True) + + chunked_context_metadata = None + if self.chunked_prefill_enabled and self._num_prefills > 0 \ + and context_lens.max() > 0: + # NOTE: it is recommend you read the `Chunked Prefill` section + # in the comment at the top of the file before trying to + # understand the following code + + num_prefills_with_context = (context_lens > 0).sum().item() + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = \ + self.chunked_prefill_workspace_size \ + // num_prefills_with_context + + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, + self.page_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(context_lens.max(), max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks + # like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + chunk_starts = \ + torch.arange(num_chunks, device=device, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, self._num_prefills) \ + * max_context_chunk + chunk_ends = torch.min(context_lens.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( + torch.int32) + zero = torch.zeros(num_chunks, + dtype=torch.int32, + device=device).unsqueeze(-1) + + chunked_context_metadata = self.cls.\ + PrefillMetadata.ChunkedContextMetadata( # type: ignore + cu_seq_lens=torch.cat( + [zero, _chunk_cu_seq_lens], dim=1), + starts=chunk_starts, + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + workspace=self.chunked_prefill_workspace, + ) + + assert max(chunked_context_metadata.max_seq_lens) <= \ + self.chunked_prefill_workspace_size + + prefill_metadata = self.cls.PrefillMetadata( # type: ignore + input_positions=input_positions[self._num_decode_tokens:], + block_table=block_table[start:, ...], + query_start_loc=query_start_loc[start:] - + query_start_loc[start], + max_query_len=seq_lens[start:].max().item(), + chunked_context=chunked_context_metadata, + ) + + decode_metadata = None + if self._num_decodes > 0: + decode_metadata = self.cls.DecodeMetadata( # type: ignore + input_positions=input_positions[:self._num_decode_tokens], + block_table=block_table[:self._num_decodes, ...], + seq_lens=seq_lens[:self._num_decodes], + ) return self.cls( - input_positions=input_positions, num_actual_tokens=num_actual_tokens, - max_query_len=max_query_len, query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table, slot_mapping=slot_mapping, head_dim=self.runner.model_config.get_head_size(), # MLACommonMetadata Chunk prefill specific num_decodes=self._num_decodes, num_decode_tokens=self._num_decode_tokens, num_prefills=self._num_prefills, - has_context=has_context, - context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, - context_chunk_starts=context_chunk_starts, - context_chunk_seq_tot=context_chunk_seq_tot, - context_chunk_max_seq_lens=context_chunk_max_seq_lens, - chunked_prefill_workspace=self.chunked_prefill_workspace, + 
prefill=prefill_metadata, + decode=decode_metadata, ) @@ -831,31 +834,24 @@ def _compute_prefill_context( kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, ): - assert attn_metadata.num_prefills is not None - assert attn_metadata.context_chunk_seq_tot is not None - assert attn_metadata.context_chunk_cu_seq_lens is not None - assert attn_metadata.context_chunk_starts is not None - assert attn_metadata.context_chunk_max_seq_lens is not None - assert attn_metadata.prefill_query_start_loc is not None - assert attn_metadata.prefill_max_query_len is not None - assert attn_metadata.prefill_block_table is not None + assert attn_metadata.prefill is not None + prefill_metadata = attn_metadata.prefill + assert prefill_metadata.chunked_context is not None output = None - iters = len(attn_metadata.context_chunk_seq_tot) - - assert attn_metadata.chunked_prefill_workspace is not None - workspace = attn_metadata.chunked_prefill_workspace + iters = len(prefill_metadata.chunked_context.seq_tot) + workspace = prefill_metadata.chunked_context.workspace for i in range(iters): - toks = attn_metadata.context_chunk_seq_tot[i] + toks = prefill_metadata.chunked_context.seq_tot[i] ops.gather_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, - block_table=attn_metadata.prefill_block_table, - cu_seq_lens=attn_metadata.context_chunk_cu_seq_lens[i], + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], batch_size=attn_metadata.num_prefills, - seq_starts=attn_metadata.context_chunk_starts[i], + seq_starts=prefill_metadata.chunked_context.starts[i], ) kv_c_normed = workspace[:toks]\ @@ -881,10 +877,10 @@ def _compute_prefill_context( q=q, k=k, v=v_padded, - cu_seqlens_q=attn_metadata.prefill_query_start_loc, - cu_seqlens_k=attn_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=attn_metadata.prefill_max_query_len, - max_seqlen_k=attn_metadata.context_chunk_max_seq_lens[i], + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i], softmax_scale=self.scale, causal=False, # Context is unmasked return_softmax_lse=True, @@ -917,10 +913,9 @@ def _forward_prefill( kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, ) -> torch.Tensor: - assert attn_metadata.prefill_query_start_loc is not None - assert attn_metadata.prefill_max_query_len is not None + assert attn_metadata.prefill is not None - has_context = attn_metadata.has_context + has_context = attn_metadata.prefill.chunked_context is not None kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv_nope\ @@ -937,10 +932,10 @@ def _forward_prefill( q=q, k=k, v=v_padded, - cu_seqlens_q=attn_metadata.prefill_query_start_loc, - cu_seqlens_k=attn_metadata.prefill_query_start_loc, - max_seqlen_q=attn_metadata.prefill_max_query_len, - max_seqlen_k=attn_metadata.prefill_max_query_len, + cu_seqlens_q=attn_metadata.prefill.query_start_loc, + cu_seqlens_k=attn_metadata.prefill.query_start_loc, + max_seqlen_q=attn_metadata.prefill.max_query_len, + max_seqlen_k=attn_metadata.prefill.max_query_len, softmax_scale=self.scale, causal=True, return_softmax_lse=has_context, @@ -1005,7 +1000,6 @@ def forward( # Restore head dim (for rotary embedding) k_pe = k_pe.unsqueeze(1) - assert hasattr(attn_metadata, "input_positions") assert attn_metadata.num_decodes is not None and 
\ attn_metadata.num_prefills is not None and \ @@ -1017,28 +1011,27 @@ def forward( decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] decode_k_pe = k_pe[:num_decode_tokens] - decode_input_positions = \ - attn_metadata.input_positions[:num_decode_tokens] prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:] - prefill_input_positions = \ - attn_metadata.input_positions[num_decode_tokens:] prefill_k_c_normed = k_c_normed[num_decode_tokens:] if has_decode: + assert attn_metadata.decode is not None decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ .view(-1, self.num_heads, self.qk_rope_head_dim) decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - decode_input_positions, decode_q_pe, decode_k_pe) + attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe) if has_prefill: + assert attn_metadata.prefill is not None prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ .view(-1, self.num_heads, self.qk_head_dim) prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( - prefill_input_positions, prefill_q_pe, prefill_k_pe) + attn_metadata.prefill.input_positions, prefill_q_pe, + prefill_k_pe) # write the latent and rope to kv cache if kv_cache.numel() > 0: diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 59580c6d52e..2aec945282f 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -58,9 +58,10 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len) if m.num_decode_tokens is not None and m.num_decode_tokens > 0: + assert m.decode is not None m.decode_tile_scheduler_metadata, m.decode_num_splits = \ get_mla_metadata( - m.decode_seq_lens, + m.decode.seq_lens, self.num_q_heads, 1, # MQA for the decode path ) @@ -115,8 +116,7 @@ def _forward_decode( attn_metadata: FlashMLAMetadata, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 - assert attn_metadata.decode_block_table is not None - assert attn_metadata.decode_seq_lens is not None + assert attn_metadata.decode is not None if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError("FP8 FlashMLA not yet supported") @@ -127,8 +127,8 @@ def _forward_decode( o, _ = flash_mla_with_kvcache( q=q, k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 - block_table=attn_metadata.decode_block_table, - cache_seqlens=attn_metadata.decode_seq_lens, + block_table=attn_metadata.decode.block_table, + cache_seqlens=attn_metadata.decode.seq_lens, head_dim_v=self.kv_lora_rank, tile_scheduler_metadata=attn_metadata. 
decode_tile_scheduler_metadata, diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index c9d015b58c0..cef7a3a9a72 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -69,8 +69,7 @@ def _forward_decode( attn_metadata: MLACommonMetadata, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 - assert attn_metadata.decode_block_table is not None - assert attn_metadata.decode_seq_lens is not None + assert attn_metadata.decode is not None if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError("FP8 Triton MLA not yet supported") @@ -107,8 +106,8 @@ def _forward_decode( # Run MQA decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, - attn_metadata.decode_block_table, - attn_metadata.decode_seq_lens, attn_logits, + attn_metadata.decode.block_table, + attn_metadata.decode.seq_lens, attn_logits, num_kv_splits, self.scale, PAGE_SIZE) return self._v_up_proj_and_o_proj(o) From 060c26b6cb5fceed0c4b25c4ae97c2b8bd9ebc0f Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 5 Mar 2025 02:33:50 +0000 Subject: [PATCH 5/8] type hinting fix Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 107 ++++++++++++--------- vllm/v1/attention/backends/mla/flashmla.py | 23 +++-- 2 files changed, 73 insertions(+), 57 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 5c49886f245..1a50f12b6a6 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -275,7 +275,42 @@ def use_cascade_attention(*args, **kwargs) -> bool: @dataclass -class MLACommonMetadata: +class MLACommonPrefillMetadata: + """ Prefill Specific Metadata """ + + @dataclass + class ChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + query_start_loc: torch.Tensor + max_query_len: int + chunked_context: Optional[ChunkedContextMetadata] = None + + +@dataclass +class MLACommonDecodeMetadata: + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + + +D = TypeVar("D", bound=MLACommonDecodeMetadata) + + +@dataclass +class MLACommonMetadata(Generic[D]): """Metadata for MLACommon. 
NOTE: Please read the comment at the top of the file before trying to @@ -305,38 +340,8 @@ class MLACommonMetadata: # The dimension of the attention heads head_dim: Optional[int] = None - @dataclass - class PrefillMetadata: - """ Prefill Specific Metadata """ - - @dataclass - class ChunkedContextMetadata: - # New for MLA (compared to FlashAttention) - # For handling chunked prefill - cu_seq_lens: torch.Tensor - starts: torch.Tensor - seq_tot: list[int] - max_seq_lens: list[int] - workspace: torch.Tensor - - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - block_table: torch.Tensor - query_start_loc: torch.Tensor - max_query_len: int - chunked_context: Optional[ChunkedContextMetadata] = None - - @dataclass - class DecodeMetadata: - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - block_table: torch.Tensor - seq_lens: torch.Tensor - - decode: Optional[DecodeMetadata] = None - prefill: Optional[PrefillMetadata] = None + decode: Optional[D] = None + prefill: Optional[MLACommonPrefillMetadata] = None def __post_init__(self): supported_head_sizes = MLACommonBackend.get_supported_head_sizes() @@ -347,19 +352,25 @@ def __post_init__(self): f"received {self.head_dim}.") -T = TypeVar("T", bound=MLACommonMetadata) +M = TypeVar("M", bound=MLACommonMetadata) -class MLACommonMetadataBuilder(Generic[T]): +class MLACommonMetadataBuilder(Generic[M]): """ NOTE: Please read the comment at the top of the file before trying to understand this class """ - def __init__(self, - runner: "GPUModelRunner", - cls: Optional[type[MLACommonMetadata]] = None): - self.cls = cls if cls is not None else MLACommonMetadata + def __init__( + self, + runner: "GPUModelRunner", + metadata_cls: Optional[type[M]] = None, + decode_metadata_cls: Optional[type[D]] = None, + ): + self.metadata_cls = metadata_cls \ + if metadata_cls is not None else MLACommonMetadata + self.decode_metadata_cls = decode_metadata_cls \ + if decode_metadata_cls is not None else MLACommonDecodeMetadata self.runner = runner scheduler_config = runner.scheduler_config model_config = runner.model_config @@ -451,7 +462,7 @@ def reorder_batch(self, input_batch: "InputBatch", self._num_prefill_tokens = num_prefill_tokens def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int) -> T: + common_prefix_len: int) -> M: assert self._num_decodes + self._num_prefills == num_reqs device = self.runner.device @@ -517,8 +528,8 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, dtype=torch.int32, device=device).unsqueeze(-1) - chunked_context_metadata = self.cls.\ - PrefillMetadata.ChunkedContextMetadata( # type: ignore + chunked_context_metadata = \ + MLACommonPrefillMetadata.ChunkedContextMetadata( cu_seq_lens=torch.cat( [zero, _chunk_cu_seq_lens], dim=1), starts=chunk_starts, @@ -530,7 +541,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, assert max(chunked_context_metadata.max_seq_lens) <= \ self.chunked_prefill_workspace_size - prefill_metadata = self.cls.PrefillMetadata( # type: ignore + prefill_metadata = MLACommonPrefillMetadata( input_positions=input_positions[self._num_decode_tokens:], block_table=block_table[start:, ...], query_start_loc=query_start_loc[start:] - @@ -541,13 +552,13 @@ def build(self, num_reqs: int, num_actual_tokens: int, 
max_query_len: int, decode_metadata = None if self._num_decodes > 0: - decode_metadata = self.cls.DecodeMetadata( # type: ignore + decode_metadata = self.decode_metadata_cls( input_positions=input_positions[:self._num_decode_tokens], block_table=block_table[:self._num_decodes, ...], seq_lens=seq_lens[:self._num_decodes], ) - return self.cls( + return self.metadata_cls( num_actual_tokens=num_actual_tokens, query_start_loc=query_start_loc, slot_mapping=slot_mapping, @@ -561,7 +572,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, ) -class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): +class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -968,7 +979,7 @@ def _forward_decode( q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: T, + attn_metadata: M, ) -> torch.Tensor: raise NotImplementedError @@ -979,7 +990,7 @@ def forward( k_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, - attn_metadata: T, + attn_metadata: M, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 2aec945282f..26a5d94340f 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -11,6 +11,7 @@ is_flashmla_supported) from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) @@ -38,16 +39,20 @@ def get_impl_cls() -> type["FlashMLAImpl"]: @dataclass -class FlashMLAMetadata(MLACommonMetadata): - decode_tile_scheduler_metadata: Optional[tuple[torch.Tensor, - torch.Tensor]] = None - decode_num_splits: Optional[torch.Tensor] = None +class FlashMLADecodeMetadata(MLACommonDecodeMetadata): + tile_scheduler_metadata: tuple[torch.Tensor, torch.Tensor] + num_splits: torch.Tensor + + +@dataclass +class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): + pass class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): def __init__(self, runner): - super().__init__(runner, cls=FlashMLAMetadata) + super().__init__(runner, decode_metadata_cls=FlashMLADecodeMetadata) self.num_q_heads = self.runner.model_config.get_num_attention_heads( self.runner.parallel_config) @@ -59,7 +64,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, if m.num_decode_tokens is not None and m.num_decode_tokens > 0: assert m.decode is not None - m.decode_tile_scheduler_metadata, m.decode_num_splits = \ + m.decode.tile_scheduler_metadata, m.decode.num_splits = \ get_mla_metadata( m.decode.seq_lens, self.num_q_heads, @@ -130,9 +135,9 @@ def _forward_decode( block_table=attn_metadata.decode.block_table, cache_seqlens=attn_metadata.decode.seq_lens, head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=attn_metadata. - decode_tile_scheduler_metadata, - num_splits=attn_metadata.decode_num_splits, + tile_scheduler_metadata=attn_metadata.decode. 
+ tile_scheduler_metadata, + num_splits=attn_metadata.decode.num_splits, softmax_scale=self.scale, causal=True, ) From be12358434ff628f04d020efa42bd326fd4bfa0c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 5 Mar 2025 02:46:46 +0000 Subject: [PATCH 6/8] override build_decode Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 21 +++++++------ vllm/v1/attention/backends/mla/flashmla.py | 34 ++++++++++++---------- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 1a50f12b6a6..830dc33b8ed 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -361,16 +361,11 @@ class MLACommonMetadataBuilder(Generic[M]): understand this class """ - def __init__( - self, - runner: "GPUModelRunner", - metadata_cls: Optional[type[M]] = None, - decode_metadata_cls: Optional[type[D]] = None, - ): + def __init__(self, + runner: "GPUModelRunner", + metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata - self.decode_metadata_cls = decode_metadata_cls \ - if decode_metadata_cls is not None else MLACommonDecodeMetadata self.runner = runner scheduler_config = runner.scheduler_config model_config = runner.model_config @@ -461,6 +456,14 @@ def reorder_batch(self, input_batch: "InputBatch", self._num_decode_tokens = num_decode_tokens self._num_prefill_tokens = num_prefill_tokens + def _build_decode(self, input_positions: torch.Tensor, + block_table: torch.Tensor, seq_lens: torch.Tensor): + return MLACommonDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + ) + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int) -> M: assert self._num_decodes + self._num_prefills == num_reqs @@ -552,7 +555,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, decode_metadata = None if self._num_decodes > 0: - decode_metadata = self.decode_metadata_cls( + decode_metadata = self._build_decode( input_positions=input_positions[:self._num_decode_tokens], block_table=block_table[:self._num_decodes, ...], seq_lens=seq_lens[:self._num_decodes], diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 26a5d94340f..d5bf9cd22f1 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -52,26 +52,28 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): def __init__(self, runner): - super().__init__(runner, decode_metadata_cls=FlashMLADecodeMetadata) + super().__init__(runner) self.num_q_heads = self.runner.model_config.get_num_attention_heads( self.runner.parallel_config) - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int): - m = super().build(num_reqs, num_actual_tokens, max_query_len, - common_prefix_len) - - if m.num_decode_tokens is not None and m.num_decode_tokens > 0: - assert m.decode is not None - m.decode.tile_scheduler_metadata, m.decode.num_splits = \ - get_mla_metadata( - m.decode.seq_lens, - self.num_q_heads, - 1, # MQA for the decode path - ) - - return m + def _build_decode(self, input_positions: torch.Tensor, + block_table: torch.Tensor, + seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: + tile_scheduler_metadata, 
num_splits = \ + get_mla_metadata( + seq_lens, + self.num_q_heads, + 1, # MQA for the decode path + ) + + return FlashMLADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + ) class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): From 43ca8fbb3363ec572ce071ef48115fd56bbaca27 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 5 Mar 2025 22:42:37 +0000 Subject: [PATCH 7/8] fix swap Signed-off-by: Lucas Wilkinson --- tests/v1/worker/test_gpu_input_batch.py | 90 +++++++++++++++++++++++- vllm/v1/attention/backends/flash_attn.py | 4 +- vllm/v1/attention/backends/mla/common.py | 23 +++--- vllm/v1/worker/gpu_input_batch.py | 25 ++++++- vllm/v1/worker/gpu_model_runner.py | 6 +- 5 files changed, 133 insertions(+), 15 deletions(-) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 72ec7370115..5f0cb1d3d3b 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import inspect from typing import Optional import numpy as np @@ -9,7 +10,8 @@ from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState, + InputBatch) VOCAB_SIZE = 1024 NUM_OUTPUT_TOKENS = 20 @@ -20,6 +22,34 @@ MAX_NUM_PROMPT_TOKENS = 64 +def _compare_objs(obj1, obj2): + attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a))) + attr_names = set([ + a[0] for a in attrs + if not (a[0].startswith('__') and a[0].endswith('__')) + ]) + for attr_name in attr_names: + a = getattr(obj1, attr_name) + b = getattr(obj2, attr_name) + + is_same = False + if isinstance(a, torch.Tensor): + if (a.numel() == 0 or b.numel() == 0): + is_same = (a.numel() == 0 and b.numel() == 0) + elif torch.allclose(a, b): + is_same = True + elif isinstance(a, np.ndarray): + if np.allclose(a, b): + is_same = True + elif isinstance(a, (BlockTable, SamplingMetadata)): + _compare_objs(a, b) + is_same = True # if we make it here must be same + elif a == b: + is_same = True + assert is_same, f"Attribute {attr_name} is different"\ + f" in {obj1} and {obj2}: {a} != {b}" + + def _remove_requests( input_batch: InputBatch, batch_size: int, reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]: @@ -254,3 +284,61 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: assert torch.allclose( expected_sampling_metadata.allowed_token_ids_mask, sampling_metadata.allowed_token_ids_mask) + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("swap_list", [((0, 1), )]) +def test_swap_states_in_input_batch(device: str, batch_size: int, + swap_list: list): + """ + Tests the logic for managing sampling metadata in the InputBatch. + + This test involves adding a set of requests to the InputBatch, + followed by removing a subset of them. Afterward, the batch is compacted, + and the `make_sampling_metadata` method is invoked on the batch. The + output of `make_sampling_metadata` is then compared against the expected + results to ensure correctness. 
+ """ + input_batch: InputBatch = InputBatch( + max_num_reqs=batch_size, + max_model_len=1024, + max_num_blocks_per_req=10, + device=torch.device(device), + pin_memory=is_pin_memory_available(), + vocab_size=1024, + ) + ref_input_batch: InputBatch = InputBatch( + max_num_reqs=batch_size, + max_model_len=1024, + max_num_blocks_per_req=10, + device=torch.device(device), + pin_memory=is_pin_memory_available(), + vocab_size=1024, + ) + + reqs: list[CachedRequestState] = [] + req_id_reqs = {} + req_id_output_token_ids = {} + # Add requests + for req_index in range(batch_size): + req: CachedRequestState = _construct_cached_request_state(req_index) + input_batch.add_request(req, req_index) + reqs.append(req) + req_id_reqs[req.req_id] = req + req_id_output_token_ids[req.req_id] = req.output_token_ids + + reordered_reqs = reqs.copy() + for swap_pair in swap_list: + reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \ + reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]] + input_batch.swap_states(swap_pair[0], swap_pair[1]) + + for req_index in range(batch_size): + req = reordered_reqs[req_index] + ref_input_batch.add_request(req, req_index) + + input_batch.refresh_sampling_metadata() + ref_input_batch.refresh_sampling_metadata() + + _compare_objs(input_batch, ref_input_batch) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 8bf7f3587bc..e7c2fd412eb 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -100,8 +100,8 @@ def __init__(self, runner: "GPUModelRunner"): self.runner = runner def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput"): - pass + scheduler_output: "SchedulerOutput") -> bool: + return False def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int): diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 830dc33b8ed..a22fa1bcd41 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -399,7 +399,7 @@ def __init__(self, self.page_size = self.runner.block_size def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput"): + scheduler_output: "SchedulerOutput") -> bool: # We now want to reorder the batch so that the "decode" requests are and # the front and the "prefill" requests are at the using the least amount # swaps possible. 
(NOTE for now we loosely use "decode" to mean requests @@ -437,14 +437,18 @@ def reorder_batch(self, input_batch: "InputBatch", num_decodes = len(decodes) num_prefills = len(prefills) first_prefill = 0 + modified_batch = False for i in range(1, min(num_decodes, num_prefills) + 1): # If the decode is at the "back" of the batch, i, we can swap it # with the prefill closest to the front of the batch if decodes[num_decodes - i] >= num_decodes: + print("Reordering ", prefills[first_prefill], + decodes[num_decodes - i]) input_batch.swap_states(prefills[first_prefill], decodes[num_decodes - i]) first_prefill += 1 + modified_batch = True else: break @@ -456,6 +460,8 @@ def reorder_batch(self, input_batch: "InputBatch", self._num_decode_tokens = num_decode_tokens self._num_prefill_tokens = num_prefill_tokens + return modified_batch + def _build_decode(self, input_positions: torch.Tensor, block_table: torch.Tensor, seq_lens: torch.Tensor): return MLACommonDecodeMetadata( @@ -482,10 +488,11 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, prefill_metadata = None if self._num_prefills > 0: - start = self._num_decodes # prefill_start + reqs_start = self._num_decodes # prefill_start + tokens_start = self._num_decode_tokens context_lens_cpu = self.runner.input_batch.\ - num_computed_tokens_cpu_tensor[start:num_reqs] + num_computed_tokens_cpu_tensor[reqs_start:num_reqs] context_lens = context_lens_cpu.to(device, non_blocking=True) chunked_context_metadata = None @@ -545,11 +552,11 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, self.chunked_prefill_workspace_size prefill_metadata = MLACommonPrefillMetadata( - input_positions=input_positions[self._num_decode_tokens:], - block_table=block_table[start:, ...], - query_start_loc=query_start_loc[start:] - - query_start_loc[start], - max_query_len=seq_lens[start:].max().item(), + input_positions=input_positions[tokens_start:], + block_table=block_table[reqs_start:, ...], + query_start_loc=query_start_loc[reqs_start:] - + query_start_loc[reqs_start], + max_query_len=seq_lens[reqs_start:].max().item(), chunked_context=chunked_context_metadata, ) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index b0b218d92b9..fb6a21f60db 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -374,8 +374,6 @@ def swap_states(self, i1: int, i2: int) -> None: self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1] self.num_tokens[i1], self.num_tokens[i2] =\ self.num_tokens[i2], self.num_tokens[i1] - self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ - self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...] self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\ self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1] self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\ @@ -397,24 +395,47 @@ def swap_states(self, i1: int, i2: int) -> None: self.min_p_cpu[i1], self.min_p_cpu[i2] =\ self.min_p_cpu[i2], self.min_p_cpu[i1] + # NOTE: the following is unsafe + # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ + # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...] + # instead, we need to temporiarily copy the data for one of the indices + # TODO(lucas): optimize this by only copying valid indices + tmp = self.token_ids_cpu[i1, ...].copy() + self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] + self.token_ids_cpu[i2, ...] 
= tmp + g1 = self.generators.get(i1) g2 = self.generators.get(i2) if g1 is not None: self.generators[i2] = g1 + else: + self.generators.pop(i2, None) if g2 is not None: self.generators[i1] = g2 + else: + self.generators.pop(i1, None) t1 = self.min_tokens.get(i1) t2 = self.min_tokens.get(i2) if t1 is not None: self.min_tokens[i2] = t1 + else: + self.min_tokens.pop(i2, None) if t2 is not None: self.min_tokens[i1] = t2 + else: + self.min_tokens.pop(i1, None) self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ self.request_lora_mapping[i2], self.request_lora_mapping[i1] self.logit_bias[i1], self.logit_bias[i2] =\ self.logit_bias[i2], self.logit_bias[i1] + + if self.allowed_token_ids_mask_cpu_tensor is not None: + self.allowed_token_ids_mask_cpu_tensor[i1], \ + self.allowed_token_ids_mask_cpu_tensor[i2] =\ + self.allowed_token_ids_mask_cpu_tensor[i2], \ + self.allowed_token_ids_mask_cpu_tensor[i1] self.block_table.swap_row(i1, i2) def condense(self, empty_req_indices: list[int]) -> None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4a1fb0514c3..de0b78e7449 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -455,8 +455,10 @@ def _prepare_inputs( # Some attention backends (namely MLA) may want to separate requests # based on if the attention computation will be compute-bound or # memory-bound. This gives them a hook to do that. - self.attn_metadata_builder.reorder_batch(self.input_batch, - scheduler_output) + modified_batch = self.attn_metadata_builder.reorder_batch( + self.input_batch, scheduler_output) + if modified_batch: + self.input_batch.refresh_sampling_metadata() # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. From 6557ea20131cfb2ca2c5ccd5d265fe00e67833b7 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 5 Mar 2025 23:19:19 +0000 Subject: [PATCH 8/8] remove debug Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index a22fa1bcd41..c98262eea1e 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -443,8 +443,6 @@ def reorder_batch(self, input_batch: "InputBatch", # If the decode is at the "back" of the batch, i, we can swap it # with the prefill closest to the front of the batch if decodes[num_decodes - i] >= num_decodes: - print("Reordering ", prefills[first_prefill], - decodes[num_decodes - i]) input_batch.swap_states(prefills[first_prefill], decodes[num_decodes - i]) first_prefill += 1
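
The two pieces of logic this series touches most are the chunked-context split computed in `MLACommonMetadataBuilder.build` and the minimal-swap decode/prefill reordering in `reorder_batch`. The sketches below exercise that arithmetic in isolation; they are illustrations only, not code from the patch, and the helper names (`split_context_into_chunks`, `max_workspace_tokens`, `reorder_decodes_first`) are invented for this example.

# Sketch 1: reproduce the chunk_starts / chunk_ends / cu_seq_lens math the
# builder computes for prefills with cached context (illustrative only).
import torch


def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def round_down(x: int, multiple: int) -> int:
    return (x // multiple) * multiple


def split_context_into_chunks(context_lens: torch.Tensor,
                              max_workspace_tokens: int, page_size: int):
    num_prefills = context_lens.numel()
    num_with_context = int((context_lens > 0).sum().item())

    # Equal workspace share per prefill, rounded down to a page boundary so
    # the gathered chunk starts stay page-aligned.
    max_context_chunk = round_down(
        max_workspace_tokens // num_with_context, page_size)
    num_chunks = cdiv(int(context_lens.max().item()), max_context_chunk)

    # e.g. max_context_chunk=256, num_chunks=3, num_prefills=4 gives
    # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
    chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
        .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
    chunk_ends = torch.min(context_lens.unsqueeze(0),
                           chunk_starts + max_context_chunk)
    chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

    zero = torch.zeros(num_chunks, dtype=torch.int32).unsqueeze(-1)
    cu_seq_lens = torch.cat(
        [zero, chunk_seq_lens.cumsum(dim=1).to(torch.int32)], dim=1)
    return chunk_starts, chunk_seq_lens, cu_seq_lens


context_lens = torch.tensor([600, 300, 128, 1000], dtype=torch.int32)
starts, seq_lens, cu = split_context_into_chunks(
    context_lens, max_workspace_tokens=1024, page_size=16)
# Rows are chunk iterations: summed over prefills they must fit the shared
# workspace; summed over chunks they recover each prefill's context length.
assert int(seq_lens.sum(dim=1).max().item()) <= 1024
assert torch.equal(seq_lens.sum(dim=0).to(torch.int32), context_lens)


# Sketch 2: the minimal-swap reordering reorder_batch performs so that decode
# requests end up in front of prefill requests (illustrative only).
def reorder_decodes_first(is_decode: list[bool]) -> list[tuple[int, int]]:
    decodes = [i for i, d in enumerate(is_decode) if d]
    prefills = [i for i, d in enumerate(is_decode) if not d]
    num_decodes = len(decodes)
    swaps, first_prefill = [], 0
    for i in range(1, min(num_decodes, len(prefills)) + 1):
        # A decode at or beyond index num_decodes is out of place; swap it
        # with the earliest prefill, which must then sit in the decode region.
        if decodes[num_decodes - i] >= num_decodes:
            swaps.append((prefills[first_prefill], decodes[num_decodes - i]))
            first_prefill += 1
        else:
            break
    return swaps


assert reorder_decodes_first([False, True, False, True]) == [(0, 3)]

Applying the returned swaps with `input_batch.swap_states` (which patch 7 fixes to deep-copy `token_ids_cpu` and to move `generators`, `min_tokens`, and the allowed-token-ids mask correctly) leaves decodes in positions `[0, num_decodes)`, which is exactly the layout the MLA metadata builder assumes when it slices out decode and prefill metadata.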