Skip to content

Commit 2c1c7df

Browse files
authored
[Models][Qwen] Replace pad with cat for better performance (#26486)
Signed-off-by: Lukas Geiger <[email protected]>
1 parent e246ad6 commit 2c1c7df

File tree

4 files changed: 6 additions and 5 deletions

vllm/model_executor/models/dots_ocr.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -680,7 +680,7 @@ def forward(
680680
dim=0,
681681
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
682682
)
683-
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
683+
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
684684

685685
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
686686
for blk in self.blocks:

vllm/model_executor/models/ernie45_vl.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -574,11 +574,12 @@ def forward(
574574
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
575575
).cumsum(dim=0, dtype=torch.int32)
576576

577+
zeros = cu_seqlens.new_zeros(1)
577578
if num_pad > 0:
578-
cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0)
579+
cu_seqlens = torch.cat([zeros, cu_seqlens, zeros])
579580
cu_seqlens[-1] = cu_seqlens[-2] + num_pad
580581
else:
581-
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
582+
cu_seqlens = torch.cat([zeros, cu_seqlens])
582583

583584
# add batch size
584585
if hidden_states.ndim == 2:

vllm/model_executor/models/qwen3_vl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def forward(
539539
dim=0,
540540
dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32,
541541
)
542-
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
542+
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
543543

544544
hidden_states = hidden_states.unsqueeze(1)
545545
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)

vllm/model_executor/models/siglip2navit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ def forward(
592592
# for more information
593593
dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
594594
)
595-
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
595+
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
596596

597597
reverse_indices = torch.argsort(window_index)
598598

0 commit comments

Comments
 (0)