@@ -941,6 +941,7 @@ def __init__(
         qk_head_dim: int,
         v_head_dim: int,
         kv_b_proj: ColumnParallelLinear,
+        q_pad_num_heads: Optional[int] = None,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported for MLA")
@@ -958,6 +959,7 @@ def __init__(
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
         self.kv_b_proj = kv_b_proj
+        self.q_pad_num_heads = q_pad_num_heads

         if use_flashinfer_prefill():
             logger.debug_once("Using FlashInfer prefill for MLA")
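The diff only threads `q_pad_num_heads` through the constructor and does not show how callers pick the value. A minimal sketch of one plausible policy, assuming the backend simply rounds the query head count up to a multiple the decode kernel tiles well (the helper name and the multiple of 16 are illustrative assumptions, not taken from this change):

```python
def pad_num_heads(num_heads: int, multiple: int = 16) -> int:
    """Hypothetical helper: round num_heads up to the kernel's preferred
    multiple; passing None instead would disable padding entirely."""
    return ((num_heads + multiple - 1) // multiple) * multiple


# e.g. 12 query heads per rank -> buffers sized for 16
assert pad_num_heads(12) == 16
```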
@@ -1133,7 +1135,7 @@ def _run_prefill_context_chunk_cudnn(self,
             True,  #Indicates actual_seq_lens are on GPU or CPU.
         )

-    def _v_up_proj(self, x):
+    def _v_up_proj(self, x, out):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         if is_rocm_aiter_fp8bmm_enabled():
@@ -1145,12 +1147,24 @@ def _v_up_proj(self, x):
                                      transpose_bm=True)
             # Convert from (B, N, V) to (B, N * V)
             x = x.reshape(-1, self.num_heads * self.v_head_dim)
+            # Copy result
+            out[:] = x
         else:
+            # Convert from (B, N * V) to (N, B, V)
+            out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
+
             # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-            x = torch.bmm(x, self.W_UV)
+            x = torch.bmm(x, self.W_UV,
+                          out=out)  # Reuse "out" to make it "hot"
+
             # Convert from (N, B, V) to (B, N * V)
-            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
-            return x
+            out_new = out.transpose(0, 1).reshape(
+                -1, self.num_heads * self.v_head_dim)
+
+            # Adjust output buffer shape back to the original (B, N * V)
+            N, B, V = out.shape
+            out.resize_((B, N * V))
+            out[:] = out_new  # Copy result

     def process_weights_after_loading(self, act_dtype: torch.dtype):

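A standalone sketch (made-up sizes, plain PyTorch) of the buffer-reuse pattern the reworked `_v_up_proj` relies on: the caller-owned `(B, N * V)` slice is reinterpreted as `(N, B, V)` so `torch.bmm(..., out=...)` targets that storage instead of allocating a fresh result tensor. PyTorch may still stage through a temporary when `out` is non-contiguous; the result is the same either way.

```python
import torch

# Made-up sizes standing in for num_heads N, batch B, kv_lora_rank L,
# v_head_dim V; they are not values taken from vLLM.
N, B, L, V = 4, 2, 8, 16

x = torch.randn(N, B, L)      # per-head attention output, (N, B, L)
w_uv = torch.randn(N, L, V)   # v up-projection weights, (N, L, V)
out = torch.empty(B, N * V)   # caller-owned output slice, (B, N * V)

# View the flat slice as (N, B, V) so bmm writes into the caller's storage.
out_view = out.view(B, N, V).transpose(0, 1)
torch.bmm(x, w_uv, out=out_view)

# The flat buffer now holds the (B, N * V) layout the caller expects.
expected = torch.bmm(x, w_uv).transpose(0, 1).reshape(B, N * V)
assert torch.allclose(out, expected)
```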
@@ -1558,6 +1572,15 @@ def forward(
             # Convert from (B, N, P) to (N, B, P)
             decode_q_nope = decode_q_nope.transpose(0, 1)

+            # Pads the head_dim if necessary (for the underlying kernel)
+            if self.q_pad_num_heads is not None:
+                B, N, L = decode_q_pe.shape
+                decode_pe_padded = decode_q_pe.new_empty(
+                    (B, self.q_pad_num_heads, L))
+                decode_pe_padded.resize_((B, N, L))
+                decode_pe_padded[:] = decode_q_pe
+                decode_q_pe = decode_pe_padded
+
             if is_rocm_aiter_fp8bmm_enabled():
                 # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
                 decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
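An isolated sketch (again with made-up sizes) of the pad-then-shrink trick used for `decode_q_pe` above and for `decode_ql_nope` in the next hunk: the tensor is allocated for the padded head count, then `resize_` shrinks only the logical shape, so downstream shape logic still sees `(B, N, L)` while the extra rows stay allocated for the kernel to use.

```python
import torch

B, N, L = 2, 12, 64     # made-up batch, head count, and rope head dim
q_pad_num_heads = 16    # assumed padded head count preferred by the kernel

decode_q_pe = torch.randn(B, N, L)

# Allocate storage for the padded head count...
padded = decode_q_pe.new_empty((B, q_pad_num_heads, L))
# ...then shrink the logical shape. resize_ never frees storage, so the
# allocation still covers B * q_pad_num_heads * L elements.
padded.resize_((B, N, L))
padded[:] = decode_q_pe
decode_q_pe = padded

assert decode_q_pe.shape == (B, N, L)
assert (decode_q_pe.untyped_storage().nbytes()
        >= B * q_pad_num_heads * L * decode_q_pe.element_size())
```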
@@ -1566,8 +1589,21 @@ def forward(
                                                       group_size=128,
                                                       transpose_bm=True)
             else:
+                # Pads the head_dim if necessary (for the underlying kernel)
+                N, B, P = decode_q_nope.shape
+                _, _, L = self.W_UK_T.shape
+                if self.q_pad_num_heads is not None:
+                    decode_ql_nope = decode_q_nope.new_empty(
+                        (self.q_pad_num_heads, B, L))
+                    decode_ql_nope.resize_((N, B, L))
+
+                else:
+                    decode_ql_nope = decode_q_nope.new_empty((N, B, L))
+
                 # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-                decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+                decode_ql_nope = torch.bmm(decode_q_nope,
+                                           self.W_UK_T,
+                                           out=decode_ql_nope)
             # Convert from (N, B, L) to (B, N, L)
             decode_ql_nope = decode_ql_nope.transpose(0, 1)

@@ -1602,5 +1638,5 @@ def forward(
             attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())

         # v_up projection
-        output[:num_decode_tokens] = self._v_up_proj(attn_out)
+        self._v_up_proj(attn_out, out=output[:num_decode_tokens])
         return output_padded