Merged

Commits (60)
dfebf51  [Bugfix] Merge multimodal embeddings by `is_embed` mask instead of to… (DarkLight1337, Apr 8, 2025)
437dacd  Rename (DarkLight1337, Apr 8, 2025)
bbe7096  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Apr 9, 2025)
57e9f03  Use #16007 (DarkLight1337, Apr 9, 2025)
d5c9555  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Aug 27, 2025)
e08deaa  Fix (DarkLight1337, Aug 27, 2025)
302b2c5  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Aug 27, 2025)
6a1307f  Update (DarkLight1337, Aug 28, 2025)
3a4740a  Fix (DarkLight1337, Aug 28, 2025)
68c54d8  Draft (DarkLight1337, Aug 28, 2025)
6ddc91e  Fix device (DarkLight1337, Aug 28, 2025)
28cc8cb  Persistent buffer (DarkLight1337, Aug 28, 2025)
c335908  Avoid unnecessary initialization (DarkLight1337, Aug 28, 2025)
cbb70ea  Fix reset (DarkLight1337, Aug 28, 2025)
76f2925  Update (DarkLight1337, Aug 28, 2025)
b6e8775  Simplify (DarkLight1337, Aug 28, 2025)
fee0d27  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Aug 30, 2025)
f71a40b  Use padded tokens (DarkLight1337, Sep 1, 2025)
3af1bdb  Fix wrong device (DarkLight1337, Sep 1, 2025)
003800e  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Sep 1, 2025)
975569d  Debug (DarkLight1337, Sep 2, 2025)
8d6b6c4  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Sep 18, 2025)
c001581  Fix? (DarkLight1337, Sep 18, 2025)
9e4512c  Simplify the code (DarkLight1337, Sep 18, 2025)
e002d44  Reduce diffs (DarkLight1337, Sep 18, 2025)
1934f25  Avoid intermediate variable (DarkLight1337, Sep 18, 2025)
573cb4b  Standardize input embeddings logic (DarkLight1337, Sep 18, 2025)
fa5e688  Cleanup (DarkLight1337, Sep 18, 2025)
0799fdb  Fix (DarkLight1337, Sep 18, 2025)
7f58edc  Fix (DarkLight1337, Sep 18, 2025)
1e9ec64  Comment out debug path (DarkLight1337, Sep 18, 2025)
439b264  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Sep 19, 2025)
a9f7e84  fix tpu recompilations (NickLucche, Sep 19, 2025)
29e0ad5  Remove sanity check for code simplicity (DarkLight1337, Sep 19, 2025)
9a6768e  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Sep 19, 2025)
f6e7e62  Update interface for all MM models (DarkLight1337, Sep 19, 2025)
74a4d5f  Avoid circular import (DarkLight1337, Sep 19, 2025)
6d3a733  Fix `get_input_embeddings` (DarkLight1337, Sep 20, 2025)
d30a4a6  Improve logging for unimpl methods (DarkLight1337, Sep 20, 2025)
ad27e91  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Sep 20, 2025)
028aedf  More fixes (DarkLight1337, Sep 20, 2025)
38058d1  Fix (DarkLight1337, Sep 20, 2025)
a71a832  Fix (DarkLight1337, Sep 20, 2025)
3d4495a  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Sep 20, 2025)
7d8f58d  Fix V0 (DarkLight1337, Sep 20, 2025)
e33a195  Rename `do_language_embed_multimodal -> handle_oov_mm_token` (DarkLight1337, Sep 21, 2025)
ead536d  Update docstring (DarkLight1337, Sep 21, 2025)
6db35c3  Add doc (DarkLight1337, Sep 21, 2025)
7dc2675  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Sep 22, 2025)
d13fca8  Update DotsOCR (DarkLight1337, Sep 22, 2025)
beb9df0  Fix wrong condition (DarkLight1337, Sep 22, 2025)
8a6fb1b  fix qwen3-vl (ywang96, Sep 22, 2025)
2eefc2d  Fix wrong condition (DarkLight1337, Sep 22, 2025)
b79860e  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Sep 22, 2025)
7769ec1  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Sep 24, 2025)
aa67033  Reduce diff (DarkLight1337, Sep 24, 2025)
3656239  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Sep 26, 2025)
9260170  Simplify (DarkLight1337, Sep 26, 2025)
2ac91b6  Fix doc (DarkLight1337, Sep 26, 2025)
3033297  Merge branch 'main' into rm-merge-mm-embeddings (DarkLight1337, Sep 27, 2025)
33 changes: 5 additions & 28 deletions docs/contributing/model/multimodal.md
@@ -66,35 +66,12 @@ Further update the model as follows:
!!! important
The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g., image) of the request.

- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.

??? code

```python
from .utils import merge_multimodal_embeddings

class YourModelForImage2Seq(nn.Module):
...

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:

# `get_input_embeddings` should already be implemented for the language
# model as one of the requirements of basic vLLM model implementation.
inputs_embeds = self.language_model.get_input_embeddings(input_ids)

if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.image_token_index)
!!! note
By default, vLLM merges the multimodal embeddings into the text embeddings based on the placeholder locations recorded in
[PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] during input processing.
This logic lives in [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].

return inputs_embeds
```
You may override this method if additional logic is required for your model when merging embeddings.

- Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model.

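The removed per-model helper and the new note above describe the same operation from two sides: the embeddings of multimodal items are scattered into the placeholder slots of the text embeddings. As a minimal sketch of that mask-based merge (the function name and shapes here are illustrative assumptions, not the exact vLLM implementation):

```python
import torch

def merge_by_mask(
    inputs_embeds: torch.Tensor,      # (num_tokens, hidden_size) text embeddings
    multimodal_embeds: torch.Tensor,  # (num_mm_tokens, hidden_size) encoder outputs
    is_multimodal: torch.Tensor,      # (num_tokens,) bool mask of placeholder slots
) -> torch.Tensor:
    # Each True position consumes one row of `multimodal_embeds`, in order,
    # so the number of placeholder slots must match the number of rows.
    assert int(is_multimodal.sum()) == multimodal_embeds.shape[0]
    return inputs_embeds.masked_scatter(
        is_multimodal.unsqueeze(-1),  # broadcast the mask over hidden_size
        multimodal_embeds.to(dtype=inputs_embeds.dtype),
    )
```

Whether the mask comes from [PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] bookkeeping or from comparing `input_ids` against a placeholder token id (as the per-model `forward` diffs below do), the merge itself reduces to this single masked scatter.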
7 changes: 6 additions & 1 deletion vllm/config/model.py
@@ -509,9 +509,14 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
else: # task == "auto"
pass
else:
debug_info = {
"architectures": architectures,
"is_generative_model": is_generative_model,
"is_pooling_model": is_pooling_model,
}
raise AssertionError("The model should be a generative or "
"pooling model when task is set to "
f"{self.task!r}.")
f"{self.task!r}. Found: {debug_info}")
Comment on lines +512 to +519

Member: This seems unrelated?

DarkLight1337 (Member, Author) on Sep 21, 2025: It helped me find out which models failed to implement get_input_embeddings, so I have decided to keep it to help OOT model developers in case they also forgot to implement this method.


self.runner = runner
self.convert = convert
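Per the review exchange above, the extra `debug_info` exists to show out-of-tree developers why the runner check failed. A hypothetical reproduction of the enriched message (the architecture name and task value below are made up for illustration):

```python
task = "generate"  # hypothetical task value
debug_info = {
    "architectures": ["MyOOTForCausalLM"],  # made-up out-of-tree model
    "is_generative_model": False,
    "is_pooling_model": False,
}
# Mirrors the raise statement in vllm/config/model.py above:
raise AssertionError("The model should be a generative or "
                     "pooling model when task is set to "
                     f"{task!r}. Found: {debug_info}")
```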
25 changes: 6 additions & 19 deletions vllm/model_executor/models/aria.py
@@ -38,8 +38,7 @@
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter, maybe_prefix,
merge_multimodal_embeddings)
is_pp_missing_parameter, maybe_prefix)


class AriaImagePixelInputs(TensorSchema):
@@ -605,19 +604,6 @@ def get_multimodal_embeddings(self,
multimodal_embeddings = self._process_image_input(image_input)
return multimodal_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -628,10 +614,11 @@
) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None

hidden_states = self.language_model(
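The Aria diff above is the template for every multimodal model touched in this PR: `forward` no longer merges embeddings itself but passes an `is_multimodal` mask built by an elementwise comparison against the placeholder token id. A runnable toy demo of that pattern (the token id, sizes, and values are arbitrary):

```python
import torch

hidden_size = 4
image_token_index = 32000  # hypothetical placeholder token id
input_ids = torch.tensor([1, 15, 32000, 32000, 7])

text_embeds = torch.zeros(input_ids.shape[0], hidden_size)
image_embeds = torch.ones(2, hidden_size)  # one row per placeholder slot

# The mask the updated forward passes build inline:
is_multimodal = input_ids == image_token_index
merged = text_embeds.masked_scatter(is_multimodal.unsqueeze(-1), image_embeds)

print(is_multimodal)  # tensor([False, False,  True,  True, False])
print(merged[2])      # tensor([1., 1., 1., 1.]) -- image rows landed in the slots
```

Passing only `inputs_embeds` onward (with `input_ids = None`) preserves the intent of the comment removed from the Aria diff: the language model always consumes embeddings, keeping the computation graph consistent across multimodal and text-only batches.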
27 changes: 6 additions & 21 deletions vllm/model_executor/models/aya_vision.py
@@ -33,8 +33,7 @@
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)


class AyaVisionImagePixelInputs(TensorSchema):
@@ -417,23 +416,6 @@ def get_multimodal_embeddings(self,

return self._process_image_input(image_input, **kwargs)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.image_token_index,
)

return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -449,8 +431,11 @@
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None

hidden_states = self.language_model.model(
12 changes: 12 additions & 0 deletions vllm/model_executor/models/bert.py
@@ -348,6 +348,9 @@ def __init__(
self.encoder = BertEncoder(vllm_config=vllm_config,
prefix=f"{prefix}.encoder")

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,
@@ -457,6 +460,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
prefix=maybe_prefix(prefix, "model"))
self.pooler = self._build_pooler(pooler_config)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,
@@ -588,6 +594,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
),
})

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.bert.get_input_embeddings(input_ids)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights)
@@ -637,6 +646,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
Pooler.for_encode(pooler_config),
})

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.bert.get_input_embeddings(input_ids)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights)
6 changes: 6 additions & 0 deletions vllm/model_executor/models/bert_with_rope.py
@@ -426,6 +426,9 @@ def __init__(self,
prefix=f"{prefix}.encoder")
self.pooler = BertPooler(self.config) if add_pooling_layer else None

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,
@@ -673,6 +676,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loaded_params = loader.load_weights(weights)
return loaded_params

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.new.get_input_embeddings(input_ids)

def forward(
self,
input_ids: Optional[torch.Tensor],
22 changes: 6 additions & 16 deletions vllm/model_executor/models/blip2.py
@@ -27,7 +27,7 @@
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)

# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
@@ -631,19 +631,6 @@ def get_multimodal_embeddings(self,
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
_IMAGE_TOKEN_ID)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -689,8 +676,11 @@ def forward(
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == _IMAGE_TOKEN_ID,
)
input_ids = None

hidden_states = self.language_model.model(input_ids,
24 changes: 7 additions & 17 deletions vllm/model_executor/models/chameleon.py
@@ -44,7 +44,7 @@
SupportsQuant)
from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)

logger = init_logger(__name__)

@@ -1002,20 +1002,6 @@ def get_multimodal_embeddings(self,
vision_embeddings = self.model.get_input_embeddings(image_tokens)
return vision_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:

inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.model.vocabulary_mapping.image_token_id)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -1032,8 +1018,12 @@
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
image_token_id = self.model.vocabulary_mapping.image_token_id
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == image_token_id,
)
input_ids = None

hidden_states = self.model(input_ids,
3 changes: 3 additions & 0 deletions vllm/model_executor/models/chatglm.py
@@ -433,6 +433,9 @@ def __init__(
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)

def compute_logits(
self,
hidden_states: torch.Tensor,
27 changes: 6 additions & 21 deletions vllm/model_executor/models/cohere2_vision.py
@@ -37,8 +37,7 @@
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)


class Cohere2VisionImagePixelInputs(TensorSchema):
@@ -430,23 +429,6 @@ def get_multimodal_embeddings(self,

return self._process_image_input(image_input, **kwargs)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.image_token_id,
)

return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -462,8 +444,11 @@
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None

hidden_states = self.language_model.model(
6 changes: 6 additions & 0 deletions vllm/model_executor/models/deepseek_eagle.py
@@ -66,6 +66,9 @@ def __init__(
self.norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
self,
input_ids: torch.Tensor,
@@ -205,6 +208,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,
6 changes: 6 additions & 0 deletions vllm/model_executor/models/deepseek_mtp.py
@@ -101,6 +101,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
)
self.logits_processor = LogitsProcessor(config.vocab_size)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
self,
input_ids: torch.Tensor,
@@ -142,6 +145,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
prefix=maybe_prefix(
prefix, "model"))

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,