
Commit 678db9b

[Performance] Remove input pads in cutlass_mla and optimize v_proj output reshape

Signed-off-by: Alexander Matveev <[email protected]>

1 parent: c4cb0af

2 files changed: +66 −19 lines

vllm/v1/attention/backends/mla/common.py

Lines changed: 42 additions & 6 deletions
@@ -941,6 +941,7 @@ def __init__(
             qk_head_dim: int,
             v_head_dim: int,
             kv_b_proj: ColumnParallelLinear,
+            q_pad_num_heads: Optional[int] = None,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
             raise NotImplementedError("KV sharing is not supported for MLA")
@@ -958,6 +959,7 @@ def __init__(
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
         self.kv_b_proj = kv_b_proj
+        self.q_pad_num_heads = q_pad_num_heads
 
         if use_flashinfer_prefill():
             logger.debug_once("Using FlashInfer prefill for MLA")
@@ -1133,7 +1135,7 @@ def _run_prefill_context_chunk_cudnn(self,
             True, #Indicates actual_seq_lens are on GPU or CPU.
         )
 
-    def _v_up_proj(self, x):
+    def _v_up_proj(self, x, out):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         if is_rocm_aiter_fp8bmm_enabled():
@@ -1145,12 +1147,24 @@ def _v_up_proj(self, x):
                                     transpose_bm=True)
             # Convert from (B, N, V) to (B, N * V)
             x = x.reshape(-1, self.num_heads * self.v_head_dim)
+            # Copy result
+            out[:] = x
         else:
+            # Convert from (B, N * V) to (N, B, V)
+            out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
+
             # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-            x = torch.bmm(x, self.W_UV)
+            x = torch.bmm(x, self.W_UV,
+                          out=out) # Reuse "out" to make it "hot"
+
             # Convert from (N, B, V) to (B, N * V)
-            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
-        return x
+            out_new = out.transpose(0, 1).reshape(
+                -1, self.num_heads * self.v_head_dim)
+
+            # Adjust output buffer shape back to the original (B, N * V)
+            N, B, V = out.shape
+            out.resize_((B, N * V))
+            out[:] = out_new # Copy result
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
 
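The reworked _v_up_proj reuses the caller-supplied output buffer as the destination of the batched matmul rather than returning a freshly allocated tensor. Below is a minimal standalone sketch of that pattern, not vLLM code; the shapes and the names x, w_uv and out are illustrative.

import torch

N, B, L, V = 16, 4, 512, 128                 # heads, tokens, kv_lora_rank, v_head_dim

x = torch.randn(B, N, L)                     # attention output in (B, N, L) layout
w_uv = torch.randn(N, L, V)                  # per-head up-projection weights
out = torch.empty(B, N * V)                  # caller-provided flat output buffer

x_nbl = x.transpose(0, 1)                    # (N, B, L): one matmul per head
out_nbv = out.view(B, N, V).transpose(0, 1)  # (N, B, V) view over the same storage

# Writing via out= means the caller never allocates a separate (N, B, V)
# result tensor that would then have to be copied into `out`.
torch.bmm(x_nbl, w_uv, out=out_nbv)

ref = torch.einsum("bnl,nlv->bnv", x, w_uv).reshape(B, N * V)
assert torch.allclose(out, ref, rtol=1e-3, atol=1e-3)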
@@ -1558,6 +1572,15 @@ def forward(
             # Convert from (B, N, P) to (N, B, P)
             decode_q_nope = decode_q_nope.transpose(0, 1)
 
+            # Pads the head_dim if necessary (for the underlying kernel)
+            if self.q_pad_num_heads is not None:
+                B, N, L = decode_q_pe.shape
+                decode_pe_padded = decode_q_pe.new_empty(
+                    (B, self.q_pad_num_heads, L))
+                decode_pe_padded.resize_((B, N, L))
+                decode_pe_padded[:] = decode_q_pe
+                decode_q_pe = decode_pe_padded
+
             if is_rocm_aiter_fp8bmm_enabled():
                 # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
                 decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
@@ -1566,8 +1589,21 @@ def forward(
                                                       group_size=128,
                                                       transpose_bm=True)
             else:
+                # Pads the head_dim if necessary (for the underlying kernel)
+                N, B, P = decode_q_nope.shape
+                _, _, L = self.W_UK_T.shape
+                if self.q_pad_num_heads is not None:
+                    decode_ql_nope = decode_q_nope.new_empty(
+                        (self.q_pad_num_heads, B, L))
+                    decode_ql_nope.resize_((N, B, L))
+
+                else:
+                    decode_ql_nope = decode_q_nope.new_empty((N, B, L))
+
                 # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-                decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+                decode_ql_nope = torch.bmm(decode_q_nope,
+                                           self.W_UK_T,
+                                           out=decode_ql_nope)
                 # Convert from (N, B, L) to (B, N, L)
                 decode_ql_nope = decode_ql_nope.transpose(0, 1)
 
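As read from the diff, the decode-side query tensors are now allocated with storage sized for q_pad_num_heads heads and then logically shrunk back to the real head count, so a later expansion to the padded shape should not need a fresh allocation or copy. A minimal sketch of this allocate-then-shrink trick follows; the shapes are illustrative, and MAX_HEADS mirrors the constant added in cutlass_mla.py.

import torch

MAX_HEADS = 128          # padded head count expected by the decode kernel (assumed)
B, N, L = 4, 16, 512     # batch, real head count, latent dim (illustrative)

q = torch.randn(B, N, L)

# Allocate storage for the padded shape, then shrink the logical shape back.
# resize_ to a smaller shape never releases storage, so the buffer can later be
# resized to (B, MAX_HEADS, L) again without reallocating.
q_buf = q.new_empty((B, MAX_HEADS, L))
q_buf.resize_((B, N, L))
q_buf[:] = q

assert q_buf.shape == (B, N, L)
assert q_buf.untyped_storage().nbytes() >= B * MAX_HEADS * L * q.element_size()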
@@ -1602,5 +1638,5 @@ def forward(
             attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
 
         # v_up projection
-        output[:num_decode_tokens] = self._v_up_proj(attn_out)
+        self._v_up_proj(attn_out, out=output[:num_decode_tokens])
         return output_padded
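
The call-site change passes the destination slice into _v_up_proj, so the projection result lands directly in the final output tensor instead of being returned and then copied into output[:num_decode_tokens]. A tiny sketch of the write-into-a-caller-provided-slice pattern, using a hypothetical project helper and illustrative shapes:

import torch

def project(x: torch.Tensor, w: torch.Tensor, out: torch.Tensor) -> None:
    # Write x @ w into the caller-provided buffer instead of returning a new tensor.
    torch.matmul(x, w, out=out)

B, D_in, D_out = 8, 64, 32
output = torch.empty(B + 4, D_out)   # full output buffer; the first B rows are decode tokens
x = torch.randn(B, D_in)
w = torch.randn(D_in, D_out)

# output[:B] is a view, so the result is written into `output` with no extra copy.
project(x, w, out=output[:B])
assert torch.allclose(output[:B], x @ w)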

vllm/v1/attention/backends/mla/cutlass_mla.py

Lines changed: 24 additions & 13 deletions
@@ -74,6 +74,8 @@ def ensure_size(self, attn_metadata: MLACommonMetadata,
 
 g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB
 
+MAX_HEADS = 128
+
 
 class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
     can_return_lse_for_decode: bool = True
@@ -92,10 +94,18 @@ def __init__(
             kv_sharing_target_layer_name: Optional[str],
             # MLA Specific Arguments
             **mla_args) -> None:
-        super().__init__(num_heads, head_size, scale, num_kv_heads,
-                         alibi_slopes, sliding_window, kv_cache_dtype,
-                         logits_soft_cap, attn_type,
-                         kv_sharing_target_layer_name, **mla_args)
+        super().__init__(num_heads,
+                         head_size,
+                         scale,
+                         num_kv_heads,
+                         alibi_slopes,
+                         sliding_window,
+                         kv_cache_dtype,
+                         logits_soft_cap,
+                         attn_type,
+                         kv_sharing_target_layer_name,
+                         q_pad_num_heads=MAX_HEADS,
+                         **mla_args)
 
         unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
         if any(unsupported_features):
@@ -158,13 +168,15 @@ def _sm100_cutlass_mla_decode(
         MAX_HEADS = 128
         assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
         if H < MAX_HEADS:
-            q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
-            q_nope_padded[:, :H] = q_nope
-            q_nope = q_nope_padded
+            # Remove this section after all tests pass
+            pass
+            # q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
+            # q_nope_padded[:, :H] = q_nope
+            # q_nope = q_nope_padded
 
-            q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
-            q_pe_padded[:, :H] = q_pe
-            q_pe = q_pe_padded
+            # q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
+            # q_pe_padded[:, :H] = q_pe
+            # q_pe = q_pe_padded
 
         assert len(page_table.shape) == 2
         B_block_table, block_num = page_table.shape
@@ -207,11 +219,10 @@ def _sm100_cutlass_mla_decode(
 
         if H < MAX_HEADS:
             # Extract the subsets of the outputs
-            returned_lse = lse[:, :H].contiguous(
-            ) if self.need_to_return_lse_for_decode else lse
+            lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
             out = out[:, :H]
 
-        return out, returned_lse
+        return out, lse
 
     def _forward_decode(
         self,
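
Dropping .contiguous() returns the sliced LSE as a strided view of the padded buffer rather than a freshly allocated copy, which skips one allocation and copy per decode call for consumers that accept non-contiguous tensors. A small sketch of the difference, with illustrative shapes:

import torch

B, H, MAX_HEADS = 4, 16, 128
lse = torch.randn(B, MAX_HEADS)       # padded LSE buffer (illustrative)

lse_view = lse[:, :H]                 # strided view: no allocation, no copy
lse_copy = lse[:, :H].contiguous()    # materializes a new (B, H) tensor

assert lse_view.data_ptr() == lse.data_ptr()   # shares storage with the padded buffer
assert lse_copy.data_ptr() != lse.data_ptr()   # separate allocation
assert not lse_view.is_contiguous() and lse_copy.is_contiguous()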
