17 changes: 6 additions & 11 deletions vllm/model_executor/models/gemma3_mm.py
@@ -30,7 +30,6 @@
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import flatten_2d_lists
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -60,7 +59,7 @@ class Gemma3ImagePixelInputs(TypedDict):
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
 
-    Shape: `(batch_size, num_images, num_embeds)`
+    Shape: `(batch_size * num_images, num_embeds)`
     """
 
 
@@ -593,6 +592,7 @@ def _parse_and_validate_image_input(
 
         pixel_values = flatten_bn(pixel_values, concat=True)
         num_crops = flatten_bn(num_crops, concat=True)
+        embed_is_patch = flatten_bn(embed_is_patch)
 
         return Gemma3ImagePixelInputs(
             type="pixel_values",
@@ -635,14 +635,10 @@ def get_multimodal_embeddings(
 
         image_features = self._process_image_input(image_input)
 
-        if kwargs.get("v0_path", False):
-            return image_features
-
-        return flatten_2d_lists(
-            scatter_patch_features(*args) for args in zip(
-                image_features,
-                image_input["embed_is_patch"],
-            ))
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )
 
     def get_input_embeddings(
         self,
@@ -671,7 +667,6 @@ def forward(self,
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
-            kwargs.update({"v0_path": True})
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
 
             inputs_embeds = self.get_input_embeddings(input_ids,
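Note: every model touched by this PR swaps the per-image `flatten_2d_lists(scatter_patch_features(*args) for args in zip(...))` pattern for one batched call. Below is a minimal sketch of what the new calling convention implies, assuming `scatter_patch_features` (the real helper lives in vllm/model_executor/models/utils.py) fills non-patch embedding slots with NaN, as the Molmo docstring deleted further down describes; the function name suffix and shapes here are illustrative only:

    # Illustrative sketch, not the real vLLM helper.
    from typing import Union
    import torch

    def scatter_patch_features_sketch(
        features: list[torch.Tensor],       # per-image (num_patches, d)
        embed_is_patch: list[torch.Tensor], # per-image bool (num_embeds,)
    ) -> list[torch.Tensor]:
        """Scatter patch features into one row per embedding token.

        Non-patch slots (e.g. <start>/<col>/<end> tokens) stay NaN,
        matching the worked example in the Molmo docstring below.
        """
        embeds = []
        for feats, is_patch in zip(features, embed_is_patch):
            out = feats.new_full((is_patch.shape[0], feats.shape[-1]),
                                 float("nan"))
            out[is_patch] = feats  # patch rows receive encoder outputs
            embeds.append(out)
        return embeds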
21 changes: 9 additions & 12 deletions vllm/model_executor/models/internvl.py
@@ -35,7 +35,6 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import flatten_2d_lists
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
@@ -66,13 +65,13 @@ class InternVLImagePixelInputs(TypedDict):
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
 
-    Shape: `(batch_size, num_images, num_embeds)`
+    Shape: `(batch_size * num_images, num_embeds)`
     """
 
 
 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
-    data: NestedTensors
+    data: Union[torch.Tensor, list[torch.Tensor]]
     """
     A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
     or a list of tensors of shape `(total_image_feature_size, hidden_size)`
@@ -867,6 +866,7 @@ def _parse_and_validate_image_input(
 
         pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
         image_num_patches = flatten_bn(image_num_patches, concat=True)
+        embed_is_patch = flatten_bn(embed_is_patch)
 
         return InternVLImagePixelInputs(
             type="pixel_values",
@@ -881,7 +881,7 @@
     def _process_image_input(
         self,
         image_input: InternVLImageInputs,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
+    ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
@@ -921,15 +921,13 @@ def get_multimodal_embeddings(
 
         image_features = self._process_image_input(image_input)
 
-        if (kwargs.get("v0_path", False)
-                or image_input["type"] != "pixel_values"):
+        if image_input["type"] != "pixel_values":
             return image_features
 
-        return flatten_2d_lists(
-            scatter_patch_features(*args) for args in zip(
-                image_features,
-                image_input["embed_is_patch"],
-            ))
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )
 
     def get_input_embeddings(
         self,
@@ -964,7 +962,6 @@ def forward(
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
-            kwargs.update({"v0_path": True})
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
             inputs_embeds = self.get_input_embeddings(input_ids,
                                                       vision_embeddings)
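The `embed_is_patch = flatten_bn(embed_is_patch)` line added to each `_parse_and_validate_image_input` is what changes the documented shape from `(batch_size, num_images, num_embeds)` to `(batch_size * num_images, num_embeds)`. A rough sketch of that flattening, under the assumption that `flatten_bn` merely merges the leading batch and per-request image dimensions (the real helper is in vllm/model_executor/models/utils.py and also offers the `concat=True` mode used for the pixel values above):

    # Assumed behavior of flatten_bn for this call site; not the real code.
    from typing import Union
    import torch

    def flatten_bn_sketch(
        x: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        if isinstance(x, torch.Tensor):
            # (batch_size, num_images, ...) -> (batch_size * num_images, ...)
            return x.flatten(0, 1)
        # List of per-request tensors: chain their leading num_images dim.
        return [per_image for per_request in x for per_image in per_request]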
22 changes: 10 additions & 12 deletions vllm/model_executor/models/llava.py
@@ -35,7 +35,6 @@
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import flatten_2d_lists
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -73,7 +72,7 @@ class PixtralHFImagePixelInputs(TypedDict):
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
 
-    Shape: `(batch_size, num_images, num_embeds)`
+    Shape: `(batch_size * num_images, num_embeds)`
    """
 
 
@@ -618,6 +617,8 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of embed_is_patch. "
                                  f"Got type: {type(embed_is_patch)}")
 
+            embed_is_patch = flatten_bn(embed_is_patch)
+
             return PixtralHFImagePixelInputs(
                 type="pixel_values_pixtral",
                 pixel_values=flatten_bn(pixel_values),
@@ -713,18 +714,16 @@ def get_multimodal_embeddings(
         if image_input is None:
             return None
 
-        vision_embeddings = self._process_image_input(image_input)
+        image_features = self._process_image_input(image_input)
 
-        if (kwargs.get("v0_path", False)
-                or image_input["type"] != "pixel_values_pixtral"):
+        if image_input["type"] != "pixel_values_pixtral":
             # The path is used for pixtral (V0 only) and llava (V0/V1)
-            return vision_embeddings
+            return image_features
 
-        return flatten_2d_lists(
-            scatter_patch_features(*args) for args in zip(
-                vision_embeddings,
-                image_input["embed_is_patch"],
-            ))
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )
 
     def get_input_embeddings(
         self,
@@ -790,7 +789,6 @@ def forward(
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
-            kwargs.update({"v0_path": True})
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
             inputs_embeds = self.get_input_embeddings(input_ids,
                                                       vision_embeddings)
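Before/after of the call-site change repeated across all four models. The stand-in helpers below exist only to make the snippet self-contained; they are assumptions about the real vLLM utilities, not their implementations:

    import torch

    def flatten_2d_lists(lists):
        # shape of the removed helper: [[a, b], [c]] -> [a, b, c]
        return [x for sub in lists for x in sub]

    def scatter_patch_features_old(feats, is_patch):
        # assumed old per-image form: returns a tuple for one image
        out = feats.new_full((is_patch.shape[0], feats.shape[-1]),
                             float("nan"))
        out[is_patch] = feats
        return (out,)

    def scatter_patch_features_new(feats_list, is_patch_list):
        # assumed new batched form: one call over all images
        return flatten_2d_lists(
            scatter_patch_features_old(f, m)
            for f, m in zip(feats_list, is_patch_list))

    image_features = [torch.ones(2, 4), torch.ones(3, 4)]
    embed_is_patch = [torch.tensor([True, False, True]),
                      torch.tensor([True, True, True, False])]

    # old call site:
    old = flatten_2d_lists(
        scatter_patch_features_old(*args)
        for args in zip(image_features, embed_is_patch))
    # new call site:
    new = scatter_patch_features_new(image_features, embed_is_patch)
    assert all(torch.equal(a.nan_to_num(), b.nan_to_num())
               for a, b in zip(old, new))

The two forms produce the same per-image embeddings; the batched form just moves the zip-and-flatten bookkeeping out of every model file.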
105 changes: 32 additions & 73 deletions vllm/model_executor/models/molmo.py
@@ -49,7 +49,6 @@
                                         PromptInsertion, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import flatten_2d_lists
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP, SupportsQuant)
@@ -72,25 +71,25 @@
 
 class MolmoImageInputs(TypedDict):
     images: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
+    """Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`"""
 
     image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
-    """Shape: `(batch_size, num_crops, num_patch)`"""
+    """Shape: `(batch_size * num_images, num_crops, num_patch)`"""
 
     feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image features correspond
     to patch tokens.
 
-    Shape: `(batch_size, num_crops, num_patch)`
+    Shape: `(batch_size * num_images, num_crops, num_patch)`
     """
 
     embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
 
-    Shape: `(batch_size, num_embeds)`
+    Shape: `(batch_size * num_images, num_embeds)`
     """
 
     num_crops: Union[torch.Tensor, list[torch.Tensor]]
@@ -696,9 +695,10 @@ def encode_image(self, images: torch.Tensor) -> torch.Tensor:
         return image_features
 
     def forward(
-        self, images: torch.Tensor, image_masks: torch.Tensor
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-
+        self,
+        images: torch.Tensor,
+        image_masks: torch.Tensor,
+    ) -> torch.Tensor:
         # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501
         batch_size, num_image = images.shape[:2]
         images = images.to(device=self.device, dtype=self.dtype)
@@ -1491,6 +1491,8 @@ def _parse_and_validate_image_input(
                                  f"Got type: {type(img_patch_id)}")
             self.img_patch_id = img_patch_id.flatten().unique().item()
 
+        embed_is_patch = flatten_bn(embed_is_patch)
+
         return MolmoImageInputs(
             images=images,
             image_masks=image_masks,
@@ -1502,13 +1504,17 @@ def _process_image_input(
     def _process_image_input(
         self,
         image_input: MolmoImageInputs,
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-        if isinstance(image_input["images"], list):
+    ) -> list[torch.Tensor]:
+        images = image_input["images"]
+        image_masks = image_input["image_masks"]
+        feat_is_patch = image_input["feat_is_patch"]
+        num_crops = image_input["num_crops"]
+
+        if isinstance(images, list):
             # Call the vision backbone on the whole batch at once
-            images_flat = flatten_bn(image_input["images"], concat=True)
-            image_masks_flat = (None if (image_masks :=
-                                         image_input["image_masks"]) is None
-                                else flatten_bn(image_masks, concat=True))
+            images_flat = flatten_bn(images, concat=True)
+            image_masks_flat = (None if image_masks is None else flatten_bn(
+                image_masks, concat=True))
 
             image_features_flat = self.vision_backbone(
                 images=images_flat.unsqueeze(0),
@@ -1517,63 +1523,19 @@ def _process_image_input(
             ).squeeze(0)
 
             # Reconstruct the batch dimension
-            image_features = image_features_flat.split(
-                image_input["num_crops"].sum(-1).tolist())
+            num_crops_per_image = [nc.sum().item() for nc in num_crops]
+            image_features = image_features_flat.split(num_crops_per_image)
         else:
             image_features = self.vision_backbone(
-                images=image_input["images"],
-                image_masks=image_input["image_masks"],
+                images=images,
+                image_masks=image_masks,
             )
 
-        return image_features
-
-    def _get_mm_embeds(
-        self,
-        features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
-        feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
-        num_crops: torch.Tensor,  # Shape: (num_images,)
-        embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
-    ) -> tuple[torch.Tensor, ...]:
-        """
-        Scatter the patch features into a contiguous tensor that corresponds
-        to the embedding tokens defined by the multimodal processor.
-
-        Note:
-            The original code only considers patch tokens as feature
-            tokens, but our processor considers all image-related tokens
-            as feature tokens because the feature tokens need to be
-            consecutive in `input_ids`.
-
-        Example:
-            A simplified example for one item in the batch:
-
-            .. code-block::
-
-                Embedding tokens (from HF processor):
-                [<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
-
-                embed_is_patch (from HF processor):
-                [ False   True   True   False  True   True   False  False ]
-
-                Encoder outputs (from model):
-                [ p1 p2 0 p3 p4 0 ]
-
-                feat_is_patch (from HF processor):
-                [ True True False True True False ]
-
-            The resulting embedding tensor is:
-                [ nan p1 p2 nan p3 p4 nan nan ]
-        """
-        num_crops_per_image = num_crops.tolist()
-        feats_per_image = features.split(num_crops_per_image)
-        f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
-
-        features = torch.cat([
+        # Only the features corresponding to patch tokens are relevant
+        return [
             feats[f_is_patch]
-            for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image)
-        ])
-
-        return scatter_patch_features(features, embed_is_patch)
+            for feats, f_is_patch in zip(image_features, feat_is_patch)
+        ]
 
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
@@ -1583,13 +1545,10 @@ def get_multimodal_embeddings(
 
         image_features = self._process_image_input(image_input)
 
-        return flatten_2d_lists(
-            self._get_mm_embeds(*args) for args in zip(
-                image_features,
-                image_input["feat_is_patch"],
-                image_input["num_crops"],
-                image_input["embed_is_patch"],
-            ))
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )
 
     def get_input_embeddings(
         self,
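A runnable sketch of Molmo's new feature recovery: split the flat crop dimension by summed crop counts to reconstruct the batch, then keep only the patch features, mirroring `_process_image_input` above. All tensor sizes below are made up for illustration:

    import torch

    hidden = 8
    num_patch = 6
    # Two entries with crop-count tensors [2, 3] and [4], as num_crops
    # might look after parsing (hypothetical values).
    num_crops = [torch.tensor([2, 3]), torch.tensor([4])]

    total_crops = int(sum(nc.sum() for nc in num_crops))  # 9
    image_features_flat = torch.randn(total_crops, num_patch, hidden)

    # Reconstruct the batch dimension, as in _process_image_input:
    num_crops_per_image = [nc.sum().item() for nc in num_crops]  # [5, 4]
    image_features = image_features_flat.split(num_crops_per_image)

    # Keep only patch features, as in the new list comprehension:
    feat_is_patch = [torch.rand(5, num_patch) > 0.2,
                     torch.rand(4, num_patch) > 0.2]
    patch_feats = [
        feats[f_is_patch]
        for feats, f_is_patch in zip(image_features, feat_is_patch)
    ]
    print([pf.shape for pf in patch_feats])  # per-entry (n_patches, hidden)

With the filtering done here, the deleted `_get_mm_embeds` helper becomes redundant: the batched `scatter_patch_features` handles the final scatter into embedding slots for all four models.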