From e3c00a1faf1d7095180d0fb02780f5ba6b9cc831 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Mon, 10 Mar 2025 05:41:56 +0000 Subject: [PATCH 1/3] [Perf] Improve MLA on V1 Signed-off-by: simon-mo --- vllm/v1/attention/backends/mla/common.py | 62 ++++++++++++++---------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 0b0f521672b0..d16cb7dbb1c8 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -223,6 +223,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( scaled_quantize) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform from vllm.utils import cdiv, round_down try: @@ -471,18 +472,23 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int) -> M: assert self._num_decodes + self._num_prefills == num_reqs + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. device = self.runner.device - query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( - device, non_blocking=True) - seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device, - non_blocking=True) block_table = ( self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( + device, non_blocking=True) slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( device, non_blocking=True).long() input_positions = self.runner.positions_cpu[:num_actual_tokens].to( device, non_blocking=True).long() + seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] + seq_lens = seq_lens_cpu.to(device, non_blocking=True) + max_query_len = seq_lens_cpu.max().item() + prefill_metadata = None if self._num_prefills > 0: reqs_start = self._num_decodes # prefill_start @@ -490,24 +496,21 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, context_lens_cpu = self.runner.input_batch.\ num_computed_tokens_cpu_tensor[reqs_start:num_reqs] - context_lens = context_lens_cpu.to(device, non_blocking=True) + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() chunked_context_metadata = None if self.chunked_prefill_enabled and self._num_prefills > 0 \ - and context_lens.max() > 0: + and max_context_len_cpu > 0: # NOTE: it is recommend you read the `Chunked Prefill` section # in the comment at the top of the file before trying to # understand the following code - num_prefills_with_context = (context_lens > 0).sum().item() - # currently we allocate an equal amount of workspace for each # prefill in the batch, we could probably use a more advanced # algorithm here and allocate more workspace to prefills with # longer context lengths - max_context_chunk = \ - self.chunked_prefill_workspace_size \ - // num_prefills_with_context + max_context_chunk = (self.chunked_prefill_workspace_size // num_prefills_with_context_cpu) # align max_context_chunk to page_size by rounding down, # currently the `gather_cache` kernel cannot handle @@ -516,30 +519,29 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, self.page_size) assert max_context_chunk > 0 - num_chunks = cdiv(context_lens.max(), max_context_chunk) + num_chunks = cdiv(max_context_len_cpu, 
max_context_chunk) # if `max_context_chunk = 256`, `num_chunks = 3`, and # `num_prefills_with_context = 4`, create a tensor that looks # like # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + # Note(simon): this is done in CPU because of downstream requirements + # of `to_list`. chunk_starts = \ - torch.arange(num_chunks, device=device, dtype=torch.int32) \ + torch.arange(num_chunks, dtype=torch.int32) \ .unsqueeze(1).expand(-1, self._num_prefills) \ * max_context_chunk - chunk_ends = torch.min(context_lens.unsqueeze(0), + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( - torch.int32) - zero = torch.zeros(num_chunks, - dtype=torch.int32, - device=device).unsqueeze(-1) + + cu_seq_lens_cpu = torch.zeros(num_chunks, self._num_prefills + 1, dtype=torch.int32, pin_memory=True) + torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32) chunked_context_metadata = \ MLACommonPrefillMetadata.ChunkedContextMetadata( - cu_seq_lens=torch.cat( - [zero, _chunk_cu_seq_lens], dim=1), - starts=chunk_starts, + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), seq_tot=chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), workspace=self.chunked_prefill_workspace, @@ -553,7 +555,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, block_table=block_table[reqs_start:, ...], query_start_loc=query_start_loc[reqs_start:] - query_start_loc[reqs_start], - max_query_len=seq_lens[reqs_start:].max().item(), + max_query_len=max_query_len, chunked_context=chunked_context_metadata, ) @@ -629,7 +631,9 @@ def __init__( # already inside an attention custom op), pull out the forward # method from the rotary embedding and call it directly # TODO(lucas): we should probably find a cleaner way to do this - self.rotary_emb = rotary_emb._forward_method + self.rotary_emb = rotary_emb.forward_native + if current_platform.is_cuda(): + self.rotary_emb = rotary_emb.forward_cuda self.q_proj = q_proj self.kv_b_proj = kv_b_proj @@ -1043,16 +1047,22 @@ def forward( decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ .view(-1, self.num_heads, self.qk_rope_head_dim) + + decode_q_pe_input = (decode_q_pe.clone().contiguous() if not decode_q_pe.is_contiguous() else decode_q_pe) + decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe) + attn_metadata.decode.input_positions, decode_q_pe_input, decode_k_pe) if has_prefill: assert attn_metadata.prefill is not None prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ .view(-1, self.num_heads, self.qk_head_dim) prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] + + prefill_q_pe_input = (prefill_q_pe.clone().contiguous() if not prefill_q_pe.is_contiguous() else prefill_q_pe) + prefill_q_pe[...], prefill_k_pe[...] 
= self.rotary_emb( - attn_metadata.prefill.input_positions, prefill_q_pe, + attn_metadata.prefill.input_positions, prefill_q_pe_input, prefill_k_pe) # write the latent and rope to kv cache From 8cf800f17a718c4c1a39e09e520c8d6e6d7ed046 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Mon, 10 Mar 2025 05:55:52 +0000 Subject: [PATCH 2/3] fix lint Signed-off-by: simon-mo --- vllm/v1/attention/backends/mla/common.py | 26 +++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index d16cb7dbb1c8..d9ef2b95b95b 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -510,7 +510,8 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, # prefill in the batch, we could probably use a more advanced # algorithm here and allocate more workspace to prefills with # longer context lengths - max_context_chunk = (self.chunked_prefill_workspace_size // num_prefills_with_context_cpu) + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) # align max_context_chunk to page_size by rounding down, # currently the `gather_cache` kernel cannot handle @@ -525,7 +526,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, # `num_prefills_with_context = 4`, create a tensor that looks # like # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] - # Note(simon): this is done in CPU because of downstream requirements + # Note(simon): this is done in CPU because of downstream's # of `to_list`. chunk_starts = \ torch.arange(num_chunks, dtype=torch.int32) \ @@ -535,8 +536,14 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, chunk_starts + max_context_chunk) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - cu_seq_lens_cpu = torch.zeros(num_chunks, self._num_prefills + 1, dtype=torch.int32, pin_memory=True) - torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32) + cu_seq_lens_cpu = torch.zeros(num_chunks, + self._num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) chunked_context_metadata = \ MLACommonPrefillMetadata.ChunkedContextMetadata( @@ -1048,10 +1055,13 @@ def forward( decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ .view(-1, self.num_heads, self.qk_rope_head_dim) - decode_q_pe_input = (decode_q_pe.clone().contiguous() if not decode_q_pe.is_contiguous() else decode_q_pe) + decode_q_pe_input = (decode_q_pe.clone().contiguous() + if not decode_q_pe.is_contiguous() else + decode_q_pe) decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - attn_metadata.decode.input_positions, decode_q_pe_input, decode_k_pe) + attn_metadata.decode.input_positions, decode_q_pe_input, + decode_k_pe) if has_prefill: assert attn_metadata.prefill is not None @@ -1059,7 +1069,9 @@ def forward( .view(-1, self.num_heads, self.qk_head_dim) prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] - prefill_q_pe_input = (prefill_q_pe.clone().contiguous() if not prefill_q_pe.is_contiguous() else prefill_q_pe) + prefill_q_pe_input = (prefill_q_pe.clone().contiguous() + if not prefill_q_pe.is_contiguous() else + prefill_q_pe) prefill_q_pe[...], prefill_k_pe[...] 
= self.rotary_emb( attn_metadata.prefill.input_positions, prefill_q_pe_input, From f8c28a4715a1324951389056b7de0fd026bd4c99 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Mon, 10 Mar 2025 15:39:08 +0000 Subject: [PATCH 3/3] simpler code from lucas Signed-off-by: simon-mo --- vllm/v1/attention/backends/mla/common.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index d9ef2b95b95b..526b792ab1f9 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1055,12 +1055,8 @@ def forward( decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ .view(-1, self.num_heads, self.qk_rope_head_dim) - decode_q_pe_input = (decode_q_pe.clone().contiguous() - if not decode_q_pe.is_contiguous() else - decode_q_pe) - decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - attn_metadata.decode.input_positions, decode_q_pe_input, + attn_metadata.decode.input_positions, decode_q_pe.contiguous(), decode_k_pe) if has_prefill: @@ -1069,13 +1065,9 @@ def forward( .view(-1, self.num_heads, self.qk_head_dim) prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] - prefill_q_pe_input = (prefill_q_pe.clone().contiguous() - if not prefill_q_pe.is_contiguous() else - prefill_q_pe) - prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( - attn_metadata.prefill.input_positions, prefill_q_pe_input, - prefill_k_pe) + attn_metadata.prefill.input_positions, + prefill_q_pe.contiguous(), prefill_k_pe) # write the latent and rope to kv cache if kv_cache.numel() > 0:
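
Note on the technique (illustrative addendum, not part of the patches): the metadata
builder in PATCH 1 avoids GPU -> CPU synchronization by doing reductions such as
.max() and .sum() on the CPU copies of the lengths, staging the cumulative-sum
tensor in pinned host memory, and moving everything to the GPU with
non_blocking=True. A minimal standalone sketch of that pattern follows; the
function name and arguments are invented for illustration only, and it assumes a
CUDA device is available (pin_memory and non-blocking copies need one).

    import torch

    def build_chunk_metadata(context_lens_cpu: torch.Tensor,
                             max_context_chunk: int,
                             device: torch.device):
        # context_lens_cpu: int32 CPU tensor, one context length per prefill.
        num_prefills = context_lens_cpu.numel()
        max_context_len = int(context_lens_cpu.max())          # CPU-only, no sync
        num_chunks = -(-max_context_len // max_context_chunk)  # ceil division

        # Chunk start offsets, e.g. [[0, 0], [C, C], [2C, 2C]] for 3 chunks and
        # 2 prefills, where C = max_context_chunk. Built on the CPU.
        chunk_starts = (torch.arange(num_chunks, dtype=torch.int32)
                        .unsqueeze(1).expand(-1, num_prefills)
                        * max_context_chunk)
        chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                               chunk_starts + max_context_chunk)
        chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

        # Cumulative lengths with a leading zero column, written straight into a
        # pinned host buffer so the copy to the GPU below can be asynchronous.
        cu_seq_lens_cpu = torch.zeros(num_chunks, num_prefills + 1,
                                      dtype=torch.int32, pin_memory=True)
        torch.cumsum(chunk_seq_lens, dim=1,
                     out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32)

        # Async host-to-device copies; .tolist() stays on the CPU tensor, so the
        # GPU is never asked for a value here.
        return (cu_seq_lens_cpu.to(device, non_blocking=True),
                chunk_starts.to(device, non_blocking=True),
                chunk_seq_lens.sum(dim=1).tolist())

In the patches themselves this logic lives in the MLA metadata builder's build()
method, with the chunked-prefill workspace split evenly across the prefills that
still have context.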