diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 0b0f521672b0..526b792ab1f9 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -223,6 +223,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( scaled_quantize) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform from vllm.utils import cdiv, round_down try: @@ -471,18 +472,23 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int) -> M: assert self._num_decodes + self._num_prefills == num_reqs + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. device = self.runner.device - query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( - device, non_blocking=True) - seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device, - non_blocking=True) block_table = ( self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( + device, non_blocking=True) slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( device, non_blocking=True).long() input_positions = self.runner.positions_cpu[:num_actual_tokens].to( device, non_blocking=True).long() + seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] + seq_lens = seq_lens_cpu.to(device, non_blocking=True) + max_query_len = seq_lens_cpu.max().item() + prefill_metadata = None if self._num_prefills > 0: reqs_start = self._num_decodes # prefill_start @@ -490,24 +496,22 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, context_lens_cpu = self.runner.input_batch.\ num_computed_tokens_cpu_tensor[reqs_start:num_reqs] - context_lens = context_lens_cpu.to(device, non_blocking=True) + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() chunked_context_metadata = None if self.chunked_prefill_enabled and self._num_prefills > 0 \ - and context_lens.max() > 0: + and max_context_len_cpu > 0: # NOTE: it is recommend you read the `Chunked Prefill` section # in the comment at the top of the file before trying to # understand the following code - num_prefills_with_context = (context_lens > 0).sum().item() - # currently we allocate an equal amount of workspace for each # prefill in the batch, we could probably use a more advanced # algorithm here and allocate more workspace to prefills with # longer context lengths - max_context_chunk = \ - self.chunked_prefill_workspace_size \ - // num_prefills_with_context + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) # align max_context_chunk to page_size by rounding down, # currently the `gather_cache` kernel cannot handle @@ -516,30 +520,35 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, self.page_size) assert max_context_chunk > 0 - num_chunks = cdiv(context_lens.max(), max_context_chunk) + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) # if `max_context_chunk = 256`, `num_chunks = 3`, and # `num_prefills_with_context = 4`, create a tensor that looks # like # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + # Note(simon): this is done in CPU because of downstream's + # of `to_list`. chunk_starts = \ - torch.arange(num_chunks, device=device, dtype=torch.int32) \ + torch.arange(num_chunks, dtype=torch.int32) \ .unsqueeze(1).expand(-1, self._num_prefills) \ * max_context_chunk - chunk_ends = torch.min(context_lens.unsqueeze(0), + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( - torch.int32) - zero = torch.zeros(num_chunks, - dtype=torch.int32, - device=device).unsqueeze(-1) + + cu_seq_lens_cpu = torch.zeros(num_chunks, + self._num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) chunked_context_metadata = \ MLACommonPrefillMetadata.ChunkedContextMetadata( - cu_seq_lens=torch.cat( - [zero, _chunk_cu_seq_lens], dim=1), - starts=chunk_starts, + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), seq_tot=chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), workspace=self.chunked_prefill_workspace, @@ -553,7 +562,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, block_table=block_table[reqs_start:, ...], query_start_loc=query_start_loc[reqs_start:] - query_start_loc[reqs_start], - max_query_len=seq_lens[reqs_start:].max().item(), + max_query_len=max_query_len, chunked_context=chunked_context_metadata, ) @@ -629,7 +638,9 @@ def __init__( # already inside an attention custom op), pull out the forward # method from the rotary embedding and call it directly # TODO(lucas): we should probably find a cleaner way to do this - self.rotary_emb = rotary_emb._forward_method + self.rotary_emb = rotary_emb.forward_native + if current_platform.is_cuda(): + self.rotary_emb = rotary_emb.forward_cuda self.q_proj = q_proj self.kv_b_proj = kv_b_proj @@ -1043,17 +1054,20 @@ def forward( decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ .view(-1, self.num_heads, self.qk_rope_head_dim) + decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe) + attn_metadata.decode.input_positions, decode_q_pe.contiguous(), + decode_k_pe) if has_prefill: assert attn_metadata.prefill is not None prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ .view(-1, self.num_heads, self.qk_head_dim) prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] + prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( - attn_metadata.prefill.input_positions, prefill_q_pe, - prefill_k_pe) + attn_metadata.prefill.input_positions, + prefill_q_pe.contiguous(), prefill_k_pe) # write the latent and rope to kv cache if kv_cache.numel() > 0: