Commit 3c22242

DarkLight1337 authored and lk-chen committed
[Misc] Clean up scatter_patch_features (vllm-project#15559)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 2574c70 commit 3c22242

File tree

6 files changed: +84 -138 lines changed
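
The same cleanup repeats across all four models below: `scatter_patch_features` now takes the per-image feature sequence and `embed_is_patch` masks directly, so every call site drops its per-image generator plus `flatten_2d_lists` wrapper, and the `v0_path` kwarg threading disappears with it. A minimal before/after sketch of the call-site pattern (names copied from the diffs below, not a complete method):

    # Before: scatter each image separately, then flatten the list of lists.
    return flatten_2d_lists(
        scatter_patch_features(*args) for args in zip(
            image_features,
            image_input["embed_is_patch"],
        ))

    # After: the helper handles the per-image iteration internally.
    return scatter_patch_features(
        image_features,
        image_input["embed_is_patch"],
    )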

vllm/model_executor/models/gemma3_mm.py

Lines changed: 6 additions & 11 deletions

@@ -30,7 +30,6 @@
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import flatten_2d_lists
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -60,7 +59,7 @@ class Gemma3ImagePixelInputs(TypedDict):
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
 
-    Shape: `(batch_size, num_images, num_embeds)`
+    Shape: `(batch_size * num_images, num_embeds)`
     """
 
 
@@ -593,6 +592,7 @@ def _parse_and_validate_image_input(
 
         pixel_values = flatten_bn(pixel_values, concat=True)
         num_crops = flatten_bn(num_crops, concat=True)
+        embed_is_patch = flatten_bn(embed_is_patch)
 
         return Gemma3ImagePixelInputs(
             type="pixel_values",
@@ -635,14 +635,10 @@ def get_multimodal_embeddings(
 
         image_features = self._process_image_input(image_input)
 
-        if kwargs.get("v0_path", False):
-            return image_features
-
-        return flatten_2d_lists(
-            scatter_patch_features(*args) for args in zip(
-                image_features,
-                image_input["embed_is_patch"],
-            ))
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )
 
     def get_input_embeddings(
         self,
@@ -671,7 +667,6 @@ def forward(self,
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
         elif inputs_embeds is None:
-            kwargs.update({"v0_path": True})
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
 
            inputs_embeds = self.get_input_embeddings(input_ids,
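
As in the other three files, `embed_is_patch` is now flattened during input parsing, which is why the documented shape changes from `(batch_size, num_images, num_embeds)` to `(batch_size * num_images, num_embeds)`. A runnable sketch of the assumed effect of `flatten_bn` without `concat=True` (the helper itself lives in `.utils`; the tensors here are hypothetical):

    import torch

    # Hypothetical batch: 2 requests with 2 images each.
    embed_is_patch = [
        [torch.tensor([True, False]), torch.tensor([True, True])],
        [torch.tensor([False, True]), torch.tensor([True, False])],
    ]

    # Assumed behavior of flatten_bn(embed_is_patch): merge the batch and
    # per-request image dimensions into one flat list of per-image masks.
    flat = [mask for per_request in embed_is_patch for mask in per_request]
    assert len(flat) == 4  # batch_size * num_images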

vllm/model_executor/models/internvl.py

Lines changed: 9 additions & 12 deletions

@@ -35,7 +35,6 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import flatten_2d_lists
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
@@ -66,13 +65,13 @@ class InternVLImagePixelInputs(TypedDict):
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
 
-    Shape: `(batch_size, num_images, num_embeds)`
+    Shape: `(batch_size * num_images, num_embeds)`
     """
 
 
 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
-    data: NestedTensors
+    data: Union[torch.Tensor, list[torch.Tensor]]
     """
     A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
     or a list of tensors of shape `(total_image_feature_size, hidden_size)`
@@ -867,6 +866,7 @@ def _parse_and_validate_image_input(
 
         pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
         image_num_patches = flatten_bn(image_num_patches, concat=True)
+        embed_is_patch = flatten_bn(embed_is_patch)
 
         return InternVLImagePixelInputs(
             type="pixel_values",
@@ -881,7 +881,7 @@ def _parse_and_validate_image_input(
     def _process_image_input(
         self,
         image_input: InternVLImageInputs,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
+    ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
@@ -921,15 +921,13 @@ def get_multimodal_embeddings(
 
         image_features = self._process_image_input(image_input)
 
-        if (kwargs.get("v0_path", False)
-                or image_input["type"] != "pixel_values"):
+        if image_input["type"] != "pixel_values":
             return image_features
 
-        return flatten_2d_lists(
-            scatter_patch_features(*args) for args in zip(
-                image_features,
-                image_input["embed_is_patch"],
-            ))
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )
 
     def get_input_embeddings(
         self,
@@ -964,7 +962,6 @@ def forward(
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
-            kwargs.update({"v0_path": True})
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
             inputs_embeds = self.get_input_embeddings(input_ids,
                                                       vision_embeddings)
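
Note the early return kept above: precomputed `image_embeds` are already aligned with the embedding tokens, so only the `pixel_values` path produces per-image patch features that need scattering (llava below applies the same guard with `pixel_values_pixtral`). A condensed, hypothetical sketch of the dispatch, not vllm's literal code:

    def _embeddings_sketch(image_input, image_features, scatter_patch_features):
        # Embedding inputs bypass the scatter step entirely.
        if image_input["type"] != "pixel_values":
            return image_features
        # Pixel inputs yield per-image patch features that must be
        # scattered back into their embedding-token positions.
        return scatter_patch_features(
            image_features,
            image_input["embed_is_patch"],
        )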

vllm/model_executor/models/llava.py

Lines changed: 10 additions & 12 deletions

@@ -35,7 +35,6 @@
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import flatten_2d_lists
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -73,7 +72,7 @@ class PixtralHFImagePixelInputs(TypedDict):
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
 
-    Shape: `(batch_size, num_images, num_embeds)`
+    Shape: `(batch_size * num_images, num_embeds)`
     """
 
 
@@ -618,6 +617,8 @@ def _parse_and_validate_image_input(
             raise ValueError("Incorrect type of embed_is_patch. "
                              f"Got type: {type(embed_is_patch)}")
 
+        embed_is_patch = flatten_bn(embed_is_patch)
+
         return PixtralHFImagePixelInputs(
             type="pixel_values_pixtral",
             pixel_values=flatten_bn(pixel_values),
@@ -713,18 +714,16 @@ def get_multimodal_embeddings(
         if image_input is None:
             return None
 
-        vision_embeddings = self._process_image_input(image_input)
+        image_features = self._process_image_input(image_input)
 
-        if (kwargs.get("v0_path", False)
-                or image_input["type"] != "pixel_values_pixtral"):
+        if image_input["type"] != "pixel_values_pixtral":
             # The path is used for pixtral (V0 only) and llava (V0/V1)
-            return vision_embeddings
+            return image_features
 
-        return flatten_2d_lists(
-            scatter_patch_features(*args) for args in zip(
-                vision_embeddings,
-                image_input["embed_is_patch"],
-            ))
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )
 
     def get_input_embeddings(
         self,
@@ -790,7 +789,6 @@ def forward(
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
-            kwargs.update({"v0_path": True})
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
             inputs_embeds = self.get_input_embeddings(input_ids,
                                                       vision_embeddings)

vllm/model_executor/models/molmo.py

Lines changed: 32 additions & 73 deletions

@@ -49,7 +49,6 @@
                                         PromptInsertion, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import flatten_2d_lists
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP, SupportsQuant)
@@ -72,25 +71,25 @@
 
 class MolmoImageInputs(TypedDict):
     images: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
+    """Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`"""
 
     image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
-    """Shape: `(batch_size, num_crops, num_patch)`"""
+    """Shape: `(batch_size * num_images, num_crops, num_patch)`"""
 
     feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image features correspond
     to patch tokens.
 
-    Shape: `(batch_size, num_crops, num_patch)`
+    Shape: `(batch_size * num_images, num_crops, num_patch)`
     """
 
     embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
 
-    Shape: `(batch_size, num_embeds)`
+    Shape: `(batch_size * num_images, num_embeds)`
     """
 
     num_crops: Union[torch.Tensor, list[torch.Tensor]]
@@ -696,9 +695,10 @@ def encode_image(self, images: torch.Tensor) -> torch.Tensor:
         return image_features
 
     def forward(
-        self, images: torch.Tensor, image_masks: torch.Tensor
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-
+        self,
+        images: torch.Tensor,
+        image_masks: torch.Tensor,
+    ) -> torch.Tensor:
         # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)  # noqa: E501
         batch_size, num_image = images.shape[:2]
         images = images.to(device=self.device, dtype=self.dtype)
@@ -1491,6 +1491,8 @@ def _parse_and_validate_image_input(
                              f"Got type: {type(img_patch_id)}")
         self.img_patch_id = img_patch_id.flatten().unique().item()
 
+        embed_is_patch = flatten_bn(embed_is_patch)
+
         return MolmoImageInputs(
             images=images,
             image_masks=image_masks,
@@ -1502,13 +1504,17 @@ def _parse_and_validate_image_input(
     def _process_image_input(
         self,
         image_input: MolmoImageInputs,
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-        if isinstance(image_input["images"], list):
+    ) -> list[torch.Tensor]:
+        images = image_input["images"]
+        image_masks = image_input["image_masks"]
+        feat_is_patch = image_input["feat_is_patch"]
+        num_crops = image_input["num_crops"]
+
+        if isinstance(images, list):
             # Call the vision backbone on the whole batch at once
-            images_flat = flatten_bn(image_input["images"], concat=True)
-            image_masks_flat = (None if (image_masks :=
-                                         image_input["image_masks"]) is None
-                                else flatten_bn(image_masks, concat=True))
+            images_flat = flatten_bn(images, concat=True)
+            image_masks_flat = (None if image_masks is None else flatten_bn(
+                image_masks, concat=True))
 
             image_features_flat = self.vision_backbone(
                 images=images_flat.unsqueeze(0),
@@ -1517,63 +1523,19 @@ def _process_image_input(
             ).squeeze(0)
 
             # Reconstruct the batch dimension
-            image_features = image_features_flat.split(
-                image_input["num_crops"].sum(-1).tolist())
+            num_crops_per_image = [nc.sum().item() for nc in num_crops]
+            image_features = image_features_flat.split(num_crops_per_image)
         else:
             image_features = self.vision_backbone(
-                images=image_input["images"],
-                image_masks=image_input["image_masks"],
+                images=images,
+                image_masks=image_masks,
             )
 
-        return image_features
-
-    def _get_mm_embeds(
-        self,
-        features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
-        feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
-        num_crops: torch.Tensor,  # Shape: (num_images,)
-        embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
-    ) -> tuple[torch.Tensor, ...]:
-        """
-        Scatter the patch features into a contiguous tensor that corresponds
-        to the embedding tokens defined by the multimodal processor.
-
-        Note:
-            The original code only considers patch tokens as feature
-            tokens, but our processor considers all image-related tokens
-            as feature tokens because the feature tokens need to be
-            consecutive in `input_ids`.
-
-        Example:
-            A simplified example for one item in the batch:
-
-            .. code-block::
-
-                Embedding tokens (from HF processor):
-                [<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
-
-                embed_is_patch (from HF processor):
-                [ False True True False True True False False ]
-
-                Encoder outputs (from model):
-                [ p1 p2 0 p3 p4 0 ]
-
-                feat_is_patch (from HF processor):
-                [ True True False True True False ]
-
-            The resulting embedding tensor is:
-                [ nan p1 p2 nan p3 p4 nan nan ]
-        """
-        num_crops_per_image = num_crops.tolist()
-        feats_per_image = features.split(num_crops_per_image)
-        f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
-
-        features = torch.cat([
+        # Only the features corresponding to patch tokens are relevant
+        return [
             feats[f_is_patch]
-            for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image)
-        ])
-
-        return scatter_patch_features(features, embed_is_patch)
+            for feats, f_is_patch in zip(image_features, feat_is_patch)
+        ]
 
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
@@ -1583,13 +1545,10 @@ def get_multimodal_embeddings(
 
         image_features = self._process_image_input(image_input)
 
-        return flatten_2d_lists(
-            self._get_mm_embeds(*args) for args in zip(
-                image_features,
-                image_input["feat_is_patch"],
-                image_input["num_crops"],
-                image_input["embed_is_patch"],
-            ))
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )
 
     def get_input_embeddings(
         self,
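
The removed `_get_mm_embeds` docstring documented the scatter semantics with a worked example; under the new split, `_process_image_input` keeps only patch features (`feats[f_is_patch]`) and the shared `scatter_patch_features` places them into the embedding layout. A runnable restatement of that example for one image (a sketch of the assumed semantics, not vllm's implementation):

    import torch

    # Encoder outputs and masks from the old docstring's example:
    #   features:       [ p1 p2 0 p3 p4 0 ]
    #   feat_is_patch:  [ T  T  F T  T  F ]
    #   embed_is_patch: [ F  T  T F  T  T F F ]
    feats = torch.tensor([[1.0], [2.0], [0.0], [3.0], [4.0], [0.0]])
    feat_is_patch = torch.tensor([True, True, False, True, True, False])
    embed_is_patch = torch.tensor(
        [False, True, True, False, True, True, False, False])

    # Step 1 (now in _process_image_input): drop non-patch features.
    patch_feats = feats[feat_is_patch]  # p1, p2, p3, p4

    # Step 2 (assumed scatter_patch_features semantics): place patch
    # features at patch-token positions, placeholder (nan) elsewhere.
    embeds = torch.full((embed_is_patch.numel(), 1), float("nan"))
    embeds[embed_is_patch] = patch_feats
    # -> [nan, p1, p2, nan, p3, p4, nan, nan]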
