Commit 54e6cd5

avoid a redundant contiguous() call for the end of the sm100 cutlass mla decode
Signed-off-by: Alexander Matveev <[email protected]>
1 parent 0514ccd commit 54e6cd5

1 file changed, +8 -3 lines changed

vllm/v1/attention/backends/mla/cutlass_mla.py

Lines changed: 8 additions & 3 deletions
@@ -210,9 +210,14 @@ def _sm100_cutlass_mla_decode(
             sm_scale,
             num_kv_splits,
         )
-        returned_lse = lse[:, :H].contiguous(
-        ) if self.need_to_return_lse_for_decode else lse
-        return out[:, :H].contiguous(), returned_lse
+
+        if H < MAX_HEADS:
+            # Extract the subsets of the outputs
+            returned_lse = lse[:, :H].contiguous(
+            ) if self.need_to_return_lse_for_decode else lse
+            out = out[:, :H]
+
+        return out, returned_lse
 
     def _sm100_forward_decode(
         self,
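
The control flow of the new tail can be sketched in plain Python, with nested lists standing in for tensors. This is a hypothetical illustration, not the vLLM code: `select_outputs`, its parameters, the list-based slicing, and the value chosen for `MAX_HEADS` are all assumptions made for the example. The sketch also binds `returned_lse` before the branch, because the hunk as shown assigns it only when `H < MAX_HEADS`. Unless something earlier in the function (outside this hunk) defines it, the `H == MAX_HEADS` path would reach the `return` with the name unbound.

```python
MAX_HEADS = 128  # assumed padded head count, for illustration only


def select_outputs(out, lse, num_heads, need_lse):
    """Trim padded head columns from each row (lists stand in for tensors)."""
    # Default to the untrimmed value so the name is bound on every path.
    returned_lse = lse
    if num_heads < MAX_HEADS:
        # Slice only when the head dimension was actually padded.
        if need_lse:
            returned_lse = [row[:num_heads] for row in lse]
        out = [row[:num_heads] for row in out]
    return out, returned_lse


# Two tokens, each with MAX_HEADS padded columns.
out = [list(range(MAX_HEADS)) for _ in range(2)]
lse = [[0.0] * MAX_HEADS for _ in range(2)]

# Padded case: both outputs are trimmed to 16 heads.
trimmed_out, trimmed_lse = select_outputs(out, lse, 16, need_lse=True)

# Unpadded case: the inputs are returned untouched, with no extra copy.
full_out, full_lse = select_outputs(out, lse, MAX_HEADS, need_lse=True)
```

When `num_heads` equals `MAX_HEADS`, the function hands back the original objects, which mirrors the commit's intent of skipping work that is not needed when no padding was applied.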
