diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md index 87d34d207cde..1d72fe97b966 100644 --- a/docs/contributing/model/multimodal.md +++ b/docs/contributing/model/multimodal.md @@ -66,35 +66,12 @@ Further update the model as follows: !!! important The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g, image) of the request. -- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings. - - ??? code - - ```python - from .utils import merge_multimodal_embeddings - - class YourModelForImage2Seq(nn.Module): - ... - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - - # `get_input_embeddings` should already be implemented for the language - # model as one of the requirements of basic vLLM model implementation. - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.image_token_index) +!!! note + By default, vLLM merges the multimodal embeddings into text embeddings depending on the information of their locations defined in + [PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing. + This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings]. - return inputs_embeds - ``` + You may override this method if additional logic is required for your model when merging embeddings. - Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model. diff --git a/vllm/config/model.py b/vllm/config/model.py index b2b68abd2c1d..3fb448ebbf36 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -509,9 +509,14 @@ def _task_to_convert(task: TaskOption) -> ConvertType: else: # task == "auto" pass else: + debug_info = { + "architectures": architectures, + "is_generative_model": is_generative_model, + "is_pooling_model": is_pooling_model, + } raise AssertionError("The model should be a generative or " "pooling model when task is set to " - f"{self.task!r}.") + f"{self.task!r}. 
Found: {debug_info}") self.runner = runner self.convert = convert diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 35c1adbdd00b..6cef5e134a4b 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -38,8 +38,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - is_pp_missing_parameter, maybe_prefix, - merge_multimodal_embeddings) + is_pp_missing_parameter, maybe_prefix) class AriaImagePixelInputs(TensorSchema): @@ -605,19 +604,6 @@ def get_multimodal_embeddings(self, multimodal_embeddings = self._process_image_input(image_input) return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_index) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -628,10 +614,11 @@ def forward( ) -> Union[torch.Tensor, IntermediateTensors]: if inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model( diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 6fd8c2fb5c56..eab996e9ba22 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -33,8 +33,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) class AyaVisionImagePixelInputs(TensorSchema): @@ -417,23 +416,6 @@ def get_multimodal_embeddings(self, return self._process_image_input(image_input, **kwargs) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.image_token_index, - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -449,8 +431,11 @@ def forward( # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model( diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index ee32587f6b1b..c984845204c4 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -348,6 +348,9 @@ def __init__( self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder") + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -457,6 +460,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "model")) self.pooler = self._build_pooler(pooler_config) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -588,6 +594,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ), }) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.bert.get_input_embeddings(input_ids) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) loaded_params = loader.load_weights(weights) @@ -637,6 +646,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): Pooler.for_encode(pooler_config), }) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.bert.get_input_embeddings(input_ids) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) loaded_params = loader.load_weights(weights) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index bfc1408ddf88..4e1eba32d259 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -426,6 +426,9 @@ def __init__(self, prefix=f"{prefix}.encoder") self.pooler = BertPooler(self.config) if add_pooling_layer else None + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -673,6 +676,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loaded_params = loader.load_weights(weights) return loaded_params + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.new.get_input_embeddings(input_ids) + def forward( self, input_ids: Optional[torch.Tensor], diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index b7455fba62c0..4d1850d07b28 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -27,7 +27,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, SupportsQuant) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) # We use this internally as placeholders since there is no image token # defined on the HuggingFace repo @@ -631,19 +631,6 @@ def get_multimodal_embeddings(self, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def 
get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - _IMAGE_TOKEN_ID) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -689,8 +676,11 @@ def forward( # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == _IMAGE_TOKEN_ID, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 79d648d749c6..f9740adb151b 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -44,7 +44,7 @@ SupportsQuant) from .utils import (flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) logger = init_logger(__name__) @@ -1002,20 +1002,6 @@ def get_multimodal_embeddings(self, vision_embeddings = self.model.get_input_embeddings(image_tokens) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.model.vocabulary_mapping.image_token_id) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -1032,8 +1018,12 @@ def forward( # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + image_token_id = self.model.vocabulary_mapping.image_token_id + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == image_token_id, + ) input_ids = None hidden_states = self.model(input_ids, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 879508400222..c182201fe256 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -433,6 +433,9 @@ def __init__( self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index 6d67eb68d51a..99edcba4d874 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -37,8 +37,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) class Cohere2VisionImagePixelInputs(TensorSchema): @@ -430,23 +429,6 @@ def get_multimodal_embeddings(self, return self._process_image_input(image_input, **kwargs) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.image_token_id, - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -462,8 +444,11 @@ def forward( # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) input_ids = None hidden_states = self.language_model.model( diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index ed7e7614800f..c42a66d86912 100644 --- a/vllm/model_executor/models/deepseek_eagle.py +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -66,6 +66,9 @@ def __init__( self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -205,6 +208,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.config.vocab_size, scale=logit_scale) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 92f311ab465b..a4623ff13cec 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -101,6 +101,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -142,6 +145,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix( prefix, "model")) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index c8ed759d2e97..b98008c83bdc 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -41,8 +41,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) # The image token id may be various _IMAGE_TOKEN = "" @@ -346,7 +345,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config tokenizer = cached_tokenizer_from_config(model_config) - self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN] + self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN] self.vision = self._init_vision_module(self.vision_config, quant_config, @@ -605,19 +604,6 @@ def get_multimodal_embeddings(self, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - 
self.image_token_id) - return inputs_embeds - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -632,8 +618,11 @@ def forward(self, # condition is for v0 compatibility elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.image_token_id, + ) input_ids = None hidden_states = self.language_model(input_ids, diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 2db350c892ae..4845f19bcbc4 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -34,8 +34,7 @@ Qwen2VLProcessingInfo) from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, - maybe_prefix, - merge_multimodal_embeddings) + maybe_prefix) from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict @@ -796,33 +795,17 @@ def _process_image_input( def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_id, - ) - - return inputs_embeds - def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -830,17 +813,14 @@ def forward( ) -> Union[torch.Tensor, IntermediateTensors]: if intermediate_tensors is not None: inputs_embeds = None - elif inputs_embeds is None and kwargs.get("pixel_values") is not None: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is None: - inputs_embeds = None - else: - assert input_ids is not None - inputs_embeds = self.get_multimodal_embeddings( - input_ids, - image_input=image_input, - ) - input_ids = None + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) + input_ids = None hidden_states = self.language_model( input_ids=input_ids, diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 74b358034ef3..a73ec4f88ffe 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -60,8 +60,7 @@ from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix, - merge_multimodal_embeddings) 
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix from .vision import get_vit_attn_backend logger = init_logger(__name__) @@ -1467,18 +1466,24 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if multimodal_embeddings is None: - return inputs_embeds - - self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds, - multimodal_embeddings, - [self.config.im_patch_id]) - return inputs_embeds + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: + self._set_visual_token_mask(input_ids) + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py index 288fbe736c32..3b24bf2f1ef8 100644 --- a/vllm/model_executor/models/ernie_mtp.py +++ b/vllm/model_executor/models/ernie_mtp.py @@ -116,6 +116,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -160,6 +163,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 53e9e6fe6e46..b99fe33a1dcc 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -42,8 +42,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix # Cannot find the following 2 numbers from hf config. _IMAGE_TOKEN_ID = 71011 @@ -342,22 +341,6 @@ def get_multimodal_embeddings(self, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - _IMAGE_TOKEN_ID, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -373,8 +356,11 @@ def forward( # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == _IMAGE_TOKEN_ID, + ) input_ids = None hidden_states = self.language_model( diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 0630ee07c347..be75e36fe23b 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -37,8 +37,7 @@ SupportsMultiModal, SupportsPP) from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) logger = init_logger(__name__) @@ -588,22 +587,6 @@ def get_multimodal_embeddings(self, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -618,8 +601,11 @@ def forward(self, elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) if (vision_embeddings is not None) and len(vision_embeddings) != 0: kwargs = self.prepare_attn_masks( input_ids, diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index 2acdba54a257..b23437a08e5a 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -632,8 +632,10 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache # them here, as the model forward has only access to the input_embeds. 
if input_ids is not None: @@ -645,15 +647,16 @@ def get_input_embeddings( self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_( per_layer_inputs) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - # NOTE: this order of processing mm items is important - [self.config.image_token_id, self.config.audio_token_id]) - return inputs_embeds + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward(self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index b088e0c0dd24..dbb5431ae491 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -1552,23 +1552,6 @@ def get_multimodal_embeddings( multimodal_embeddings += video_embeddings return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if (multimodal_embeddings is not None - and len(multimodal_embeddings) != 0 - and all(embed.numel() > 0 for embed in multimodal_embeddings)): - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id], - ) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py index c572978e6220..826d541e571b 100644 --- a/vllm/model_executor/models/glm4_moe_mtp.py +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -132,6 +132,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -173,6 +176,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix( prefix, "model")) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index bf33575859ae..ace9c05daf15 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -43,7 +43,7 @@ from .chatglm import ChatGLMBaseModel, ChatGLMModel from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import flatten_bn, merge_multimodal_embeddings +from .utils import flatten_bn, isin_list class GLMVImagePixelInputs(TensorSchema): @@ -607,28 +607,6 @@ def get_multimodal_embeddings(self, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.transformer.get_input_embeddings(input_ids) - - if 
multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=[ - self.config.boi_token_id, - self.config.pad_token_id, - self.config.eoi_token_id, - ], - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -644,8 +622,15 @@ def forward( # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, [ + self.config.boi_token_id, + self.config.pad_token_id, + self.config.eoi_token_id, + ]), + ) input_ids = None hidden_states = self.transformer(input_ids, positions, diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index a5849184339b..8a02da58ea0b 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -52,8 +52,7 @@ from .blip2 import Blip2QFormerModel from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, embed_multimodal, - init_vllm_registered_model, maybe_prefix) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix ### Audio Input @@ -720,6 +719,9 @@ def _process_audio_input( # Split variable length features into a tuple return torch.split(masked_embeds, audio_input["audio_embed_sizes"]) + def get_language_model(self) -> torch.nn.Module: + return self.language_model + def get_multimodal_embeddings( self, **kwargs: object, @@ -728,7 +730,7 @@ def get_multimodal_embeddings( audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: return [] - return None + audio_features = self._process_audio_input(audio_input) return audio_features @@ -736,19 +738,21 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: - """Compute the merged LLM / audio embeddings.""" - if multimodal_embeddings is None \ - or len(multimodal_embeddings) == 0: - return self.language_model.get_input_embeddings(input_ids) + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - inputs_embeds = embed_multimodal( + return super().get_input_embeddings( input_ids, - self.config.audio_token_index, - self.language_model.model.get_input_embeddings, - multimodal_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, ) - return inputs_embeds def forward( self, @@ -765,7 +769,11 @@ def forward( # condition is for v0 compatibility. 
elif inputs_embeds is None: audio_embeds = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds) + inputs_embeds = self.get_input_embeddings( + input_ids, + audio_embeds, + is_multimodal=input_ids == self.config.audio_token_index, + ) input_ids = None model_output = self.language_model(input_ids, positions, diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index 8a23a6b45bc7..d28c97116790 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -989,6 +989,9 @@ def update_physical_experts_metadata( moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index 4d39ff9ae79e..f851688bf7ba 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -45,8 +45,8 @@ from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from .utils import (AutoWeightsLoader, init_vllm_registered_model, isin_list, + maybe_prefix) from .vision import get_vision_encoder_info EOT = "<|endofturn|>" @@ -691,7 +691,7 @@ def get_language_model(self) -> torch.nn.Module: def get_multimodal_embeddings( self, **kwargs: Unpack[HCXVisionMultimodalInputs], - ) -> Optional[MultiModalEmbeddings]: + ) -> MultiModalEmbeddings: multimodal_embeddings = list() if kwargs.get("pixel_values_images") is not None: @@ -736,26 +736,6 @@ def get_multimodal_embeddings( multimodal_embeddings.append(_multimodal_embeddings_videos) return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - placeholder_token_id=[ - self.config.image_token_id, - self.config.video_token_id, - ], - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -771,8 +751,13 @@ def forward( # condition is for v0 compatibility. 
elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=isin_list( + input_ids, + [self.config.image_token_id, self.config.video_token_id]), + ) input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 79e130119ae8..3334ee224253 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -52,8 +52,7 @@ # yapf: enable from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .llama import LlamaModel -from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix class Idefics3ImagePixelInputs(TensorSchema): @@ -539,10 +538,7 @@ def image_pixels_to_features( return image_hidden_states - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.text_model.get_input_embeddings(input_ids) def forward( @@ -695,22 +691,6 @@ def get_multimodal_embeddings(self, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -726,8 +706,11 @@ def forward( # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) input_ids = None hidden_states = self.model.text_model(input_ids, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index f13e590cd243..d40df9b43dd4 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, MutableSequence -from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, - Union, overload, runtime_checkable) +from typing import (TYPE_CHECKING, Callable, ClassVar, Literal, Optional, + Protocol, Union, overload, runtime_checkable) import numpy as np import torch @@ -20,7 +20,7 @@ QuantizationConfig) from vllm.utils import supports_kw -from .interfaces_base import is_pooling_model +from .interfaces_base import VllmModel, is_pooling_model if TYPE_CHECKING: from vllm.config import VllmConfig @@ -90,7 +90,7 @@ def get_multimodal_embeddings(self, """ ... - def get_language_model(self) -> torch.nn.Module: + def get_language_model(self) -> VllmModel: """ Returns the underlying language model used for text generation. @@ -102,17 +102,84 @@ def get_language_model(self) -> torch.nn.Module: """ ... 
+    @overload
+    def get_input_embeddings(self, input_ids: Tensor) -> Tensor:
+        ...
+
+    @overload
+    def get_input_embeddings(
+        self,
+        input_ids: Tensor,
+        multimodal_embeddings: MultiModalEmbeddings,
+        *,
+        is_multimodal: torch.Tensor,
+        handle_oov_mm_token: bool = False,
+    ) -> Tensor:
+        ...
+
+    def _get_text_embeddings(
+        self,
+        input_ids: Tensor,
+        get_input_embeddings: Callable[[Tensor], Tensor],
+        *,
+        is_multimodal: Optional[Tensor],
+        handle_oov_mm_token: bool,
+    ) -> Tensor:
+        if handle_oov_mm_token and is_multimodal is not None:
+            is_text = ~is_multimodal
+            text_embeds = get_input_embeddings(input_ids[is_text])
+
+            return torch.empty(
+                (input_ids.shape[0], text_embeds.shape[1]),
+                dtype=text_embeds.dtype,
+                device=text_embeds.device,
+            ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
+
+        return get_input_embeddings(input_ids)
+
     def get_input_embeddings(
         self,
         input_ids: Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        *,
+        is_multimodal: Optional[Tensor] = None,
+        handle_oov_mm_token: bool = False,
     ) -> Tensor:
         """
-        Returns the input embeddings merged from the text embeddings from
-        input_ids and the multimodal embeddings generated from multimodal
-        kwargs.
+        Apply token embeddings to `input_ids`.
+
+        If `multimodal_embeddings` is passed, scatter them into
+        `input_ids` according to the mask `is_multimodal`.
+
+        In case the multi-modal token IDs exceed the vocabulary size of
+        the language model, you can set `handle_oov_mm_token=True`
+        to avoid calling the language model's `get_input_embeddings` method
+        on those tokens. Note however that doing so increases memory usage
+        as an additional buffer is needed to hold the input embeddings.
         """
-        ...
+        from .utils import _merge_multimodal_embeddings
+
+        inputs_embeds = self._get_text_embeddings(
+            input_ids,
+            self.get_language_model().get_input_embeddings,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+
+        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+            return inputs_embeds
+
+        if is_multimodal is None:
+            raise ValueError(
+                "`get_input_embeddings` now requires `is_multimodal` arg, "
+                "please update your model runner according to "
+                "https://github.com/vllm-project/vllm/pull/16229.")
+
+        return _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            multimodal_embeddings=multimodal_embeddings,
+            is_multimodal=is_multimodal,
+        )


 @runtime_checkable
diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py
index 8fdf70e35a2b..84146db0943c 100644
--- a/vllm/model_executor/models/interfaces_base.py
+++ b/vllm/model_executor/models/interfaces_base.py
@@ -41,6 +41,13 @@ def __init__(
     ) -> None:
         ...

+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Apply token embeddings to `input_ids`."""
+        ...
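Every per-model `forward` change in this diff follows the same calling convention: build a boolean `is_multimodal` mask from that model's placeholder token ID(s) and pass it to the inherited `get_input_embeddings`. The standalone sketch below (toy sizes, hypothetical `image_token_id`; the real merge lives in `_merge_multimodal_embeddings` and uses `masked_scatter_`) illustrates what that contract means, not how vLLM implements it:

```python
import torch

# Toy illustration of the `is_multimodal` contract -- not the vLLM implementation.
vocab_size, hidden_size = 100, 8
image_token_id = 7                                 # hypothetical placeholder ID
embed_weight = torch.randn(vocab_size, hidden_size)

input_ids = torch.tensor([1, 7, 7, 7, 2])          # one image -> three placeholder tokens
is_multimodal = input_ids == image_token_id        # boolean mask built by the caller
mm_embeds = torch.randn(3, hidden_size)            # stand-in for get_multimodal_embeddings()

# Text embeddings are looked up for every position, then the positions flagged
# by the mask are overwritten with the multimodal embeddings.
inputs_embeds = embed_weight[input_ids]
inputs_embeds[is_multimodal] = mm_embeds
assert inputs_embeds.shape == (5, hidden_size)
```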
+ def forward( self, input_ids: torch.Tensor, @@ -54,6 +61,19 @@ def _check_vllm_model_init(model: Union[type[object], object]) -> bool: return supports_kw(model_init, "vllm_config") +def _check_vllm_model_get_input_embeddings( + model: Union[type[object], object]) -> bool: + model_get_input_embeddings = getattr(model, "get_input_embeddings", None) + if not callable(model_get_input_embeddings): + logger.warning( + "The model (%s) is missing the `get_input_embeddings` method.", + model, + ) + return False + + return True + + def _check_vllm_model_forward(model: Union[type[object], object]) -> bool: model_forward = getattr(model, "forward", None) if not callable(model_forward): @@ -88,7 +108,9 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]: def is_vllm_model( model: Union[type[object], object], ) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]: - return _check_vllm_model_init(model) and _check_vllm_model_forward(model) + return (_check_vllm_model_init(model) + and _check_vllm_model_get_input_embeddings(model) + and _check_vllm_model_forward(model)) @runtime_checkable diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index 197d629b906f..545dad1a96f5 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -40,8 +40,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, isin_list, maybe_prefix) class InternS1MultiModalProjector(nn.Module): @@ -767,24 +766,24 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - context_token_ids = [ - token_id for token_id in (self.img_context_token_id, - self.video_context_token_id) - if token_id is not None - ] - assert len(context_token_ids) >= 1 + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -802,9 +801,17 @@ def forward( # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. 
elif inputs_embeds is None: + context_token_ids = [ + token_id for token_id in (self.img_context_token_id, + self.video_context_token_id) + if token_id is not None + ] vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, context_token_ids), + ) input_ids = None forward_kwargs = { diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index f4004e518e3b..78aac8541434 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -43,7 +43,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + isin_list, maybe_prefix) IMG_START = '' IMG_END = '' @@ -1339,24 +1339,24 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - context_token_ids = [ - token_id for token_id in (self.img_context_token_id, - self.video_context_token_id) - if token_id is not None - ] - assert len(context_token_ids) >= 1 + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -1374,9 +1374,17 @@ def forward( # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. 
elif inputs_embeds is None: + context_token_ids = [ + token_id for token_id in (self.img_context_token_id, + self.video_context_token_id) + if token_id is not None + ] vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, context_token_ids), + ) input_ids = None forward_kwargs = { diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 3b6fdba22512..62a71b7b1fa8 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1450,24 +1450,6 @@ def get_multimodal_embeddings( multimodal_embeddings += video_embeddings return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - [ - self.config.image_token_id, - self.config.video_token_id, - ], - ) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 503627865c4a..db032736f914 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -66,7 +66,6 @@ from vllm.model_executor.models.interfaces import (SupportsMultiModal, SupportsPP) from vllm.model_executor.models.moonvit import MoonVitPretrainedModel -from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors) @@ -424,26 +423,6 @@ def get_multimodal_embeddings(self, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - - # `get_input_embeddings` should already be implemented for the language - # model as one of the requirements of basic vLLM model implementation. - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None and len( - multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.media_placeholder_token_id) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -462,14 +441,12 @@ def forward( if image_input is None: inputs_embeds = None else: - inputs_embeds = self.get_input_embeddings(input_ids) image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( + inputs_embeds = self.get_input_embeddings( input_ids, - inputs_embeds, image_embeds, - placeholder_token_id=self.config. 
- media_placeholder_token_id, + is_multimodal=input_ids == + self.config.media_placeholder_token_id, ) input_ids = None diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 53c36e4e52d8..f9def222a1ec 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -522,6 +522,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index a203af53205c..235275c0940a 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -37,9 +37,9 @@ from vllm.model_executor.models.llama4 import (Llama4DecoderLayer, Llama4ForCausalLM) from vllm.model_executor.models.utils import extract_layer_index -from vllm.multimodal.inputs import NestedTensors -from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings +from .interfaces import SupportsMultiModal +from .utils import AutoWeightsLoader, maybe_prefix logger = init_logger(__name__) @@ -79,10 +79,7 @@ def __init__( self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -194,6 +191,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.config.vocab_size, scale=logit_scale) + def get_language_model(self) -> torch.nn.Module: + return self.model + + get_input_embeddings = SupportsMultiModal.get_input_embeddings # type: ignore + def forward( self, input_ids: torch.Tensor, @@ -220,20 +222,3 @@ def transform(inputs): skip_prefixes=(["lm_head."]), ) loader.load_weights(map(transform, weights)) - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - - return inputs_embeds diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 2ff2d54a83aa..d6e6fd3fcfe9 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -73,6 +73,9 @@ def __init__( self.config.hidden_size, bias=False) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -149,6 +152,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.config.vocab_size, scale=logit_scale) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 55b6ae6ee0e9..34b8ea0ca536 100644 --- 
a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -18,7 +18,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaForCausalLM) @@ -144,10 +143,7 @@ def __init__( eps=self.config.rms_norm_eps, ) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -239,6 +235,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): requires_grad=False, ) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -302,11 +301,3 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): skip_substrs=skip_substrs, ) loader.load_weights(model_weights.items()) - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - return inputs_embeds diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 4d8ed95b6cc8..6f3cfd88aee2 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -41,8 +41,7 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import get_vision_encoder_info @@ -676,22 +675,6 @@ def get_multimodal_embeddings(self, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -744,8 +727,11 @@ def forward( # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c9133fde1455..e132389c4f06 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -25,8 +25,8 @@ LlavaDummyInputsBuilder, LlavaLikeConfig, LlavaMultiModalProjector, init_vision_tower_for_llava) from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal, - flatten_bn, init_vllm_registered_model, maybe_prefix) +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix) class LlavaNextImagePixelInputs(TensorSchema): @@ -474,19 +474,21 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - if multimodal_embeddings is None \ - or len(multimodal_embeddings) == 0: - return self.language_model.get_input_embeddings(input_ids) - - inputs_embeds = embed_multimodal( + return super().get_input_embeddings( input_ids, - self.config.image_token_index, - self.language_model.model.get_input_embeddings, - multimodal_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, ) - return inputs_embeds def forward( self, @@ -549,8 +551,11 @@ def forward( # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 610fb188d57d..2642d8c77cf3 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -30,8 +30,7 @@ from .llava import init_vision_tower_for_llava from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import get_vision_encoder_info @@ -415,19 +414,6 @@ def get_multimodal_embeddings(self, vision_embeddings = self._process_video_pixels(video_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.video_token_index) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -449,8 +435,11 @@ def forward( # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.video_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index cee9ddaf94cc..906858f4e2f4 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -850,19 +850,6 @@ def get_multimodal_embeddings(self, return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_index, self.config.video_token_index]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py index 82648ba668ca..0bf04e0e7e2f 100644 --- a/vllm/model_executor/models/midashenglm.py +++ b/vllm/model_executor/models/midashenglm.py @@ -54,8 +54,7 @@ from vllm.transformers_utils.configs.midashenglm import DashengConfig from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix _Tuple2 = Union[int, tuple[int, 
int], Sequence[int]] @@ -744,21 +743,6 @@ def get_multimodal_embeddings(self, return [] return self._process_audio_input(audio_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.decoder.get_input_embeddings(input_ids) - if multimodal_embeddings and len(multimodal_embeddings) > 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.audio_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -771,8 +755,11 @@ def forward( inputs_embeds = None elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=input_ids == self.config.audio_token_id, + ) input_ids = None return self.decoder.model(input_ids, diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index b4abe458e477..9c1e36094c4a 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ b/vllm/model_executor/models/mimo_mtp.py @@ -117,6 +117,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -158,6 +161,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config.hidden_size, prefix=maybe_prefix(prefix, "lm_head")) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index a17c4f004d75..bffc9a0c125e 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -71,8 +71,7 @@ from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, flatten_bn, isin_list, maybe_prefix # For profile run _MAX_FRAMES_PER_VIDEO = 16 @@ -1144,23 +1143,6 @@ def get_multimodal_embeddings(self, return self._process_multimodal_inputs(modalities) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.llm.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - assert len(self.mm_token_ids) > 0 - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - list(self.mm_token_ids), - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -1178,8 +1160,11 @@ def forward( elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, list(self.mm_token_ids)), + ) input_ids = None hidden_states = self.llm.model( diff --git 
a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index cc9a959f6331..a92890c9f7b5 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -592,10 +592,7 @@ def _clear_prefill_cache(self, attn_metadata, dtype=torch.long) minimax_cache_tensors[:, slots_tensor, ...] = 0 - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward(self, @@ -687,10 +684,7 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs( batch_size) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) def forward(self, diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index d81ac8c704e7..d41b9d3f14fe 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -28,7 +28,7 @@ from .pixtral import PixtralHFVisionModel from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) class MiniMaxVL01ImagePixelInputs(TensorSchema): @@ -218,22 +218,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def get_language_model(self) -> torch.nn.Module: return self.language_model @@ -403,8 +387,11 @@ def forward( inputs_embeds = None elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index ba6da4403ae1..31571ce962d1 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -38,8 +38,7 @@ SupportsMultiModal, SupportsPP) from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import get_vision_encoder_info @@ -524,22 +523,6 @@ def get_multimodal_embeddings(self, return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if 
multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -592,8 +575,11 @@ def forward( # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 79e315f79489..3af5267928cd 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -56,8 +56,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llama4 import Llama4ForCausalLM -from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix from .vision import run_dp_sharded_vision_model @@ -813,24 +812,6 @@ def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings: return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None and len( - multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -846,8 +827,11 @@ def forward( # this condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None return self.language_model(input_ids, positions, intermediate_tensors, diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 1d5da3139de9..e4a51b369737 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -43,6 +43,9 @@ def __init__(self, config: ModernBertConfig): eps=config.layer_norm_eps, bias=config.norm_bias) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.tok_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -220,6 +223,9 @@ def __init__( eps=config.norm_eps, bias=config.norm_bias) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings.get_input_embeddings(input_ids) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.hf_to_vllm_mapper.apply(weights) @@ -333,6 +339,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ), }) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): self_weights = [] diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 201bf83cac58..054caee9e8a4 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -58,7 +58,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) # TODO: hard-coded for now. Consider making it configurable. VIT_LAYERS = [-2, -9] @@ -819,10 +819,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -1481,24 +1478,6 @@ def get_multimodal_embeddings(self, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - assert self.img_patch_id is not None - - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.img_patch_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.LongTensor, @@ -1515,8 +1494,11 @@ def forward( # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.img_patch_id, + ) input_ids = None hidden_states = self.model(input_ids, diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 2b68d40cf2c6..505806a15c89 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -35,8 +35,7 @@ from vllm.model_executor.models.radio import RadioModel from vllm.model_executor.models.utils import (flatten_bn, init_vllm_registered_model, - maybe_prefix, - merge_multimodal_embeddings) + isin_list, maybe_prefix) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs, MultiModalKwargsItems, @@ -1096,8 +1095,8 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: return modalities - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: # Validate the multimodal input keyword arguments modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if modalities is None: @@ -1121,30 +1120,6 @@ def get_multimodal_embeddings( return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if (multimodal_embeddings is not None - and len(multimodal_embeddings) != 0): - context_token_ids = [ - token_id for token_id in (self.img_context_token_id, - self.video_context_token_id) - if token_id is not None - ] - assert len(context_token_ids) >= 1 - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - - return inputs_embeds - def get_language_model(self) -> torch.nn.Module: return self.language_model @@ -1163,9 +1138,17 @@ def forward( # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. 
elif inputs_embeds is None: + context_token_ids = [ + token_id for token_id in (self.img_context_token_id, + self.video_context_token_id) + if token_id is not None + ] vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, context_token_ids), + ) input_ids = None hidden_states = self.language_model( diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index 3abbff8c717d..2627a262e958 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -38,7 +38,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) IMG_START = '' IMG_END = '' @@ -576,20 +576,24 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - context_token_ids = [self.img_context_token_id] - assert len(context_token_ids) >= 1 + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -608,8 +612,11 @@ def forward( # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.img_context_token_id, + ) input_ids = None forward_kwargs = { diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 2e0b1fb2a13f..e7e30ee8df0f 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -295,6 +295,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], self.config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -408,6 +411,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index bd525b6780e0..8503d3f71d1c 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -48,7 +48,6 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import merge_multimodal_embeddings # Cannot find the following number from hf config. IMAGE_TOKEN = "" @@ -501,19 +500,6 @@ def get_multimodal_embeddings(self, return image_features - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.llm.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.image_pad_token_id) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -529,8 +515,11 @@ def forward( # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.image_pad_token_id, + ) input_ids = None # up until here we have an inputs_embeds 100% numerical identity diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index f18e38ce154d..2ecc7bff07e0 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -585,17 +585,6 @@ def get_multimodal_embeddings(self, return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.llm.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - tmp = torch.concat(multimodal_embeddings, dim=0) - inputs_embeds[input_ids == self.image_pad_token_id] = tmp - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -612,8 +601,11 @@ def forward( elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.image_pad_token_id, + ) input_ids = None # up until here we have a inputs_embeds 100% numerical identity diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index aef510230461..f07f444819f4 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -26,8 +26,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import get_vision_encoder_info logger = init_logger(__name__) @@ -362,19 +361,6 @@ def get_multimodal_embeddings(self, vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_index) - return inputs_embeds - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -388,8 +374,11 @@ def forward(self, # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index a2b201fe4228..ea34c8d92f13 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -51,9 +51,9 @@ from .clip import CLIPVisionModel from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, SupportsQuant) -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import (AutoWeightsLoader, WeightsMapper, + _merge_multimodal_embeddings, flatten_bn, + init_vllm_registered_model, maybe_prefix) logger = init_logger(__name__) @@ -643,14 +643,31 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.embed_tokens(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.image_token_id) - return inputs_embeds + inputs_embeds = self._get_text_embeddings( + input_ids, + self.embed_tokens, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + if is_multimodal is None: + raise ValueError( + "`get_input_embeddings` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229.") + + return _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) def forward(self, input_ids: torch.Tensor, @@ -666,8 +683,11 @@ def forward(self, # condition is for v0 compatibility elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.image_token_id, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index d2a3a8cc0496..ed9376633be4 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -1341,12 +1341,12 @@ def _process_image_input( image_attention_mask) return image_embeds - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: - return None + return [] # The result multimodal_embeddings is tuple of tensors, with each # tensor corresponding to a multimodal data item (image or video).
@@ -1370,18 +1370,6 @@ def get_multimodal_embeddings( return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 47b5ad55ab2d..15b09c7ae2bc 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -1151,7 +1151,6 @@ def get_multimodal_embeddings(self, modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] - return None # The result multimodal_embeddings is tuple of tensors, with each # tensor corresponding to a multimodal data item (image or video). @@ -1175,19 +1174,6 @@ def get_multimodal_embeddings(self, return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.embed_tokens(input_ids) - if multimodal_embeddings is not None and len( - multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 7b197844c8b6..2c04b6f0f4f9 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -50,8 +50,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs try: @@ -433,22 +432,6 @@ def get_multimodal_embeddings(self, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.vision_args.image_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -465,8 +448,11 @@ def forward( # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.vision_args.image_token_id, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 5f27230c913b..bfa398ee43b5 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -865,24 +865,26 @@ def get_multimodal_embeddings(self, multimodal_embeddings += audio_embeddings return multimodal_embeddings + # TODO (ywang96): support overlapping modality embeddings so that + # `use_audio_in_video` will work on V1. def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - - # TODO (ywang96): support overlapping modality embeddings so that - # `use_audio_in_video` will work on V1. - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, [ - self.config.image_token_index, - self.config.video_token_index, - self.config.audio_token_index - ]) - return inputs_embeds + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def get_multimodal_embeddings_v0( self, **kwargs: object) -> Optional[NestedTensors]: diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index adb21373056c..5b092b42205f 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -1365,19 +1365,6 @@ def get_multimodal_embeddings(self, multimodal_embeddings += video_embeddings return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 762ab42e5929..9dfa29eef5ce 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -49,8 +49,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix # # === Audio Inputs === # 
@@ -438,19 +437,6 @@ def get_multimodal_embeddings(self, masked_audio_features = self._process_audio_input(audio_input) return masked_audio_features - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.audio_token_index) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -467,8 +453,11 @@ def forward( # condition is for v0 compatibility. elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=input_ids == self.config.audio_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index d4e195246bf1..8192c3ce05dd 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1459,19 +1459,6 @@ def get_multimodal_embeddings(self, return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index f1aeb99a4d37..5d0b66f91ace 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -79,7 +79,8 @@ from .qwen2_vl import Qwen2VLProcessingInfo from .qwen3 import Qwen3ForCausalLM, Qwen3Model from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - maybe_prefix, merge_multimodal_embeddings) + _merge_multimodal_embeddings, maybe_prefix, + merge_multimodal_embeddings) from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) @@ -1324,17 +1325,22 @@ def get_multimodal_embeddings( return multimodal_embeddings def _compute_deepstack_embeds( - self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor, - multimodal_embeddings: MultiModalEmbeddings) -> torch.Tensor: - visual_lens = [ - x.shape[0] if isinstance(x, torch.Tensor) else len(x) - for x in multimodal_embeddings - ] + self, + inputs_embeds: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings, + is_multimodal: torch.Tensor, + ) -> tuple[torch.Tensor, MultiModalEmbeddings]: + visual_lens = [len(x) for x in multimodal_embeddings] multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0) - multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split( # noqa:E501 - multimodal_embeddings_cat, [self.visual_dim, self.multiscale_dim], - dim=-1) + ( + multimodal_embeddings_main, + multimodal_embeddings_multiscale, + ) = torch.split( + multimodal_embeddings_cat, + [self.visual_dim, 
self.multiscale_dim], + dim=-1, + ) multimodal_embeddings = torch.split(multimodal_embeddings_main, visual_lens, @@ -1346,39 +1352,62 @@ def _compute_deepstack_embeds( inputs_embeds.size(0), self.deepstack_num_level * inputs_embeds.size(1)) - deepstack_input_embeds = merge_multimodal_embeddings( - input_ids, - deepstack_input_embeds, - multimodal_embeddings_multiscale, - placeholder_token_id=[ - self.config.image_token_id, self.config.video_token_id - ], + deepstack_input_embeds = _merge_multimodal_embeddings( + inputs_embeds=deepstack_input_embeds, + multimodal_embeddings=multimodal_embeddings_multiscale, + is_multimodal=is_multimodal, ) deepstack_input_embeds = deepstack_input_embeds.view( inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim) deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2) + return deepstack_input_embeds, multimodal_embeddings def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - deepstack_input_embeds = None - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - if self.use_deepstack: - deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds( # noqa:E501 - input_ids, inputs_embeds, multimodal_embeddings) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id]) + inputs_embeds = self._get_text_embeddings( + input_ids, + self.language_model.get_input_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + if is_multimodal is None: + raise ValueError( + "`get_input_embeddings` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229.") if self.use_deepstack: - if deepstack_input_embeds is None: - deepstack_input_embeds = torch.zeros_like( - inputs_embeds).unsqueeze(0).repeat( - self.deepstack_num_level, 1, 1).contiguous() + ( + deepstack_input_embeds, + multimodal_embeddings, + ) = self._compute_deepstack_embeds( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + else: + deepstack_input_embeds = None + + inputs_embeds = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + if deepstack_input_embeds is not None: self._set_deepstack_input_embeds(deepstack_input_embeds) return inputs_embeds diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 90200f319464..dc11b60604a9 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -45,7 +45,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .qwen import QWenBaseModel, QWenModel -from .utils import flatten_bn, merge_multimodal_embeddings +from .utils import flatten_bn class QwenImagePixelInputs(TensorSchema): @@ -756,21 +756,6 @@ def get_multimodal_embeddings(self, vision_embeddings = self._process_image_input(image_input)
return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.transformer.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.transformer.visual.image_pad_id) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -786,8 +771,12 @@ def forward( # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == + self.transformer.visual.image_pad_id, + ) input_ids = None hidden_states = self.transformer(input_ids, positions, diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index ba405be41687..53e698c4fa80 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -218,6 +218,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.roberta.get_input_embeddings(input_ids) + def forward( self, input_ids: Optional[torch.Tensor], diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 893ce4497c31..f9a107c06085 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -38,7 +38,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) IMG_START = '' IMG_END = '' @@ -842,19 +842,24 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - assert self.img_context_token_id is not None + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.img_context_token_id, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -873,8 +878,11 @@ def forward( # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.img_context_token_id, + ) input_ids = None forward_kwargs = { diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index c774171b9dcd..c5b82b0ca4a0 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -483,6 +483,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 0cce0c78f8dc..0fe723d59483 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -395,6 +395,9 @@ def __init__( self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 5f6ad5885043..ad295ef44732 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -37,8 +37,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import run_dp_sharded_vision_model @@ -996,10 +995,13 @@ def _process_image_input( 1 else cur_feature[0]) return merged_image_features - def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: - return None + return [] vision_embeddings = self._process_image_input(image_input) return vision_embeddings @@ -1007,24 +1009,21 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: - if multimodal_embeddings is None: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - else: - is_text = input_ids != self.config.image_token_id - text_ids = input_ids[is_text] - text_embeds = self.language_model.model.get_input_embeddings( - text_ids) - 
inputs_embeds = torch.empty(input_ids.shape[0], - text_embeds.shape[-1], - dtype=text_embeds.dtype, - device=text_embeds.device) - inputs_embeds[is_text] = text_embeds - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_id) - return inputs_embeds + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -1038,10 +1037,11 @@ def forward( inputs_embeds = None elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) input_ids = None hidden_states = self.language_model(input_ids, diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 3660efdc079a..1145bea41480 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -40,7 +40,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) from .vision import VisionEncoderInfo, get_vision_encoder_info @@ -589,22 +589,6 @@ def get_multimodal_embeddings(self, return [] return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -617,8 +601,11 @@ def forward( inputs_embeds = None elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index b9dfa8e9b6f5..938b02e3e04b 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -233,6 +233,9 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: # We do not really use any input tokens and therefore no embeddings # to be calculated. 
However, due to the mandatory token ids in diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 19dd242f16eb..3d7b06633f34 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -52,8 +52,8 @@ from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of -from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, - SupportsQuant) +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP, SupportsQuant) from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, flatten_bn, make_empty_intermediate_tensors_factory, maybe_prefix) @@ -797,6 +797,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings()(input_ids) + def compute_logits( self, hidden_states: torch.Tensor, @@ -873,13 +876,19 @@ def forward( multimodal_embeds = self.get_multimodal_embeddings(**kwargs) if multimodal_embeds is not None: inputs_embeds = self.get_input_embeddings( - input_ids, multimodal_embeds) + input_ids, + multimodal_embeds, + is_multimodal=input_ids == self.config.image_token_id, + ) input_ids = None model_output = super().forward(input_ids, positions, intermediate_tensors, inputs_embeds) return model_output + def get_language_model(self) -> torch.nn.Module: + return self.model + def get_multimodal_embeddings(self, **kwargs): pixel_values = kwargs.pop("pixel_values", None) pixel_values = pixel_values if pixel_values is not None else kwargs.pop( @@ -934,15 +943,42 @@ def get_multimodal_embeddings(self, **kwargs): def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings=None, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings()(input_ids) - if (multimodal_embeddings is not None - and len(multimodal_embeddings) != 0): - mask = (input_ids == self.config.image_token_id) - mask = mask.unsqueeze(-1).expand_as(inputs_embeds) - multimodal_embeddings = torch.cat(multimodal_embeddings) - - inputs_embeds = inputs_embeds.masked_scatter( - mask, multimodal_embeddings) - return inputs_embeds + """ + Apply token embeddings to `input_ids`. + + If `multimodal_embeddings` is passed, scatter them into + `input_ids` according to the mask `is_multimodal`. + + In case the multi-modal token IDs exceed the vocabulary size of + the language model, you can set `handle_oov_mm_token=True` + to avoid calling the language model's `get_input_embeddings` method + on those tokens.
+ """ + from .utils import _merge_multimodal_embeddings + + inputs_embeds = self._get_text_embeddings( + input_ids, + self.model.get_input_embeddings(), + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + if is_multimodal is None: + raise ValueError( + "`get_input_embeddings` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229.") + + return _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 12ae9487ad9d..77e886c22e63 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -33,8 +33,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>" _MAX_ENCODER_BATCH_SIZE = 16 @@ -555,19 +554,21 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: - # The audio token index is not included in the embedding table - # We need to remove it before embedding lookup - safe_input_ids = input_ids.clone() - safe_input_ids[safe_input_ids == self.config.audio_token_index] = 0 - inputs_embeds = self.language_model.get_input_embeddings( - safe_input_ids) - if multimodal_embeddings is not None and len( - multimodal_embeddings) > 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.audio_token_index) - return inputs_embeds + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward(self, input_ids: torch.Tensor, @@ -601,8 +602,11 @@ def forward(self, elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=input_ids == self.config.audio_token_index, + ) input_ids = None language_model = self.language_model diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 51cd41c864f0..7b3f20c6b28a 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -4,7 +4,7 @@ import itertools from collections.abc import Iterable, Mapping from dataclasses import dataclass, field -from typing import Any, Callable, Literal, Optional, Protocol, Union, overload +from typing import Any, Literal, Optional, Protocol, Union, overload import torch import torch.nn as nn @@ -391,8 +391,8 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: def 
_merge_multimodal_embeddings( inputs_embeds: torch.Tensor, - is_multimodal: torch.Tensor, multimodal_embeddings: NestedTensors, + is_multimodal: torch.Tensor, ) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the @@ -402,61 +402,35 @@ def _merge_multimodal_embeddings( Note: This updates ``inputs_embeds`` in place. """ - flattened = _flatten_embeddings(multimodal_embeddings) + if len(multimodal_embeddings) == 0: + return inputs_embeds + + mm_embeds_flat = _flatten_embeddings(multimodal_embeddings) + input_dtype = inputs_embeds.dtype + try: - # This is equivalent to: inputs_embeds[is_multimodal] = flattened. + # For debugging + # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype) + + # NOTE: This can avoid D2H sync (#22105), but fails to + # raise an error if is_multimodal.sum() < len(mm_embeds_flat) inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), - flattened.to(dtype=inputs_embeds.dtype)) + mm_embeds_flat.to(dtype=input_dtype)) except RuntimeError as e: + num_actual_tokens = len(mm_embeds_flat) num_expected_tokens = is_multimodal.sum().item() - assert isinstance(num_expected_tokens, int) - if flattened.shape[0] != num_expected_tokens: + if num_actual_tokens != num_expected_tokens: expr = _embedding_count_expression(multimodal_embeddings) + raise ValueError( - f"Attempted to assign {expr} = {flattened.shape[0]} " + f"Attempted to assign {expr} = {num_actual_tokens} " f"multimodal tokens to {num_expected_tokens} placeholders" ) from e - else: - raise ValueError("Error during masked scatter operation") from e - - return inputs_embeds - - -def embed_multimodal( - input_ids: torch.Tensor, - multimodal_token_id: int, - get_text_embeds: Callable[[torch.Tensor], torch.Tensor], - multimodal_embeds: NestedTensors, -) -> torch.Tensor: - """ - Embed token IDs and multimodal inputs and combine their embeddings. - - ``multimodal_token_id`` is used to determine whether a token ID should - be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``. - Compared to ``merge_multimodal_embeddings`, this avoids running - ``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]`` - which causes issues when the placeholder token ID exceeds the - vocabulary size of the language model. - """ - is_multimodal = input_ids == multimodal_token_id - is_text = ~is_multimodal - - text_embeds = get_text_embeds(input_ids[is_text]) - merged_embeds = torch.empty( - (input_ids.shape[0], text_embeds.shape[1]), - dtype=text_embeds.dtype, - device=text_embeds.device, - ) + raise ValueError("Error during masked scatter operation") from e - merged_embeds[is_text] = text_embeds - - return _merge_multimodal_embeddings( - merged_embeds, - is_multimodal, - multimodal_embeds, - ) + return inputs_embeds def merge_multimodal_embeddings( @@ -491,23 +465,29 @@ def merge_multimodal_embeddings( This updates ``inputs_embeds`` in place. 
""" if isinstance(placeholder_token_id, list): - placeholder_token_id = torch.tensor( - placeholder_token_id, - pin_memory=is_pin_memory_available()).to(device=input_ids.device, - non_blocking=True) - return _merge_multimodal_embeddings( - inputs_embeds, - torch.isin(input_ids, placeholder_token_id), - multimodal_embeddings, - ) + is_multimodal = isin_list(input_ids, placeholder_token_id) + else: + is_multimodal = (input_ids == placeholder_token_id) return _merge_multimodal_embeddings( inputs_embeds, - (input_ids == placeholder_token_id), - multimodal_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, ) +def isin_list( + elements: torch.Tensor, + test_elements_list: list[int], +) -> torch.Tensor: + test_elements = torch.tensor( + test_elements_list, + pin_memory=is_pin_memory_available(), + ).to(device=elements.device, non_blocking=True) + + return torch.isin(elements, test_elements) + + class LayerFn(Protocol): def __call__(self, prefix: str) -> torch.nn.Module: diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index b33e8d09c4be..f93e7ccfd06f 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -45,10 +45,8 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsTranscription) -from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription +from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix logger = init_logger(__name__) @@ -376,9 +374,14 @@ def forward( # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. 
elif inputs_embeds is None: + audio_encoder = self.tokenizer.instruct.audio_encoder + audio_tok_id = audio_encoder.audio_token audio_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - audio_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + audio_embeddings, + is_multimodal=input_ids == audio_tok_id, + ) input_ids = None hidden_states = self.language_model.model(input_ids, @@ -421,20 +424,6 @@ def get_multimodal_embeddings( return audio_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - audio_encoder = self.tokenizer.instruct.audio_encoder - audio_tok_id = audio_encoder.audio_token - - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, audio_tok_id) - return inputs_embeds - def _parse_and_validate_audio_arrays( self, **kwargs: object) -> Union[list[torch.Tensor], None]: audio_arrays = kwargs.pop("audio_arrays", None) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index de3e4f0592a6..7beeeddf988f 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -579,10 +579,7 @@ def forward( hidden_states = self.layer_norm(hidden_states) return hidden_states - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -916,7 +913,10 @@ def get_multimodal_embeddings(self, def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: # This method just returns the decoder sequence embeddings since # Whisper does not have encoder text tokens. diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 51e54e0dc337..1b5bafb9ca1b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -18,6 +18,7 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata @@ -64,8 +65,10 @@ def __init__( # hidden size (e.g., Llama 3.3 70B). 
self.hidden_size = self.draft_model_config.get_hidden_size() - self.is_multimodal_model = vllm_config.model_config \ - .is_multimodal_model + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( + vllm_config.model_config) self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None @@ -175,7 +178,8 @@ def propose( last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - mm_embeds: Optional[list[torch.Tensor]] = None, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], + torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -219,18 +223,21 @@ def propose( # copy inputs to buffer for cudagraph self._set_positions(num_tokens, target_positions) self.hidden_states[:num_tokens] = target_hidden_states - if self.is_multimodal_model: - input_ids = self.input_ids[:num_tokens] - inputs_embeds = self.model.get_input_embeddings( - input_ids, - multimodal_embeddings=mm_embeds or None, + + if self.supports_mm_inputs: + mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) + + self.inputs_embeds[:num_tokens] = self.model.get_input_embeddings( + self.input_ids[:num_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, ) - self.inputs_embeds[:num_tokens] = inputs_embeds - inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = None + inputs_embeds = self.inputs_embeds[:num_input_tokens] else: - inputs_embeds = None input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None with set_forward_context(per_layer_attn_metadata, self.vllm_config, @@ -372,14 +379,15 @@ def propose( self.input_ids[:batch_size] = input_ids self._set_positions(batch_size, clamped_positions) self.hidden_states[:batch_size] = hidden_states - if self.is_multimodal_model: - inputs_embeds = self.model.get_input_embeddings(input_ids) - self.inputs_embeds[:batch_size] = inputs_embeds - inputs_embeds = self.inputs_embeds[:input_batch_size] + if self.supports_mm_inputs: + self.inputs_embeds[:batch_size] = \ + self.model.get_input_embeddings(input_ids) + input_ids = None + inputs_embeds = self.inputs_embeds[:input_batch_size] else: - inputs_embeds = None input_ids = self.input_ids[:input_batch_size] + inputs_embeds = None # Run the model. 
with set_forward_context(per_layer_attn_metadata, @@ -849,7 +857,7 @@ def load_model(self, target_model: nn.Module) -> None: self.attn_layer_names = list(draft_attn_layer_names) - if self.is_multimodal_model: + if self.supports_mm_inputs: # Even if the target model is multimodal, we can also use # text-only draft models try: @@ -861,7 +869,7 @@ def load_model(self, target_model: nn.Module) -> None: logger.warning( "Draft model does not support multimodal inputs, " "falling back to text-only mode") - self.is_multimodal_model = False + self.supports_mm_inputs = False if supports_multimodal(target_model): # handle multimodality @@ -933,7 +941,7 @@ def dummy_run( ) -> None: with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): - if self.is_multimodal_model: + if self.supports_mm_inputs: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] else: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 22a177dd7cc7..1bae0d4ce4d1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -368,6 +368,11 @@ def __init__( self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, dtype=torch.int64) + # Only relevant for multimodal models + if self.supports_mm_inputs: + self.is_mm_embed = self._make_buffer(self.max_num_tokens, + dtype=torch.bool) + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: # NOTE: `mrope_positions` is implemented with one additional dummy @@ -1627,9 +1632,16 @@ def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", shift_computed_tokens: int = 0, - ) -> list[torch.Tensor]: + ) -> tuple[list[torch.Tensor], torch.Tensor]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + + mm_embeds = list[torch.Tensor]() + is_mm_embed = self.is_mm_embed.cpu + is_mm_embed[:total_num_scheduled_tokens] = False + + req_start_idx = 0 should_sync_mrope_positions = False - mm_embeds: list[torch.Tensor] = [] + for req_id in self.input_batch.req_ids: mm_embeds_req: list[torch.Tensor] = [] @@ -1638,6 +1650,7 @@ def _gather_mm_embeddings( req_state = self.requests[req_id] num_computed_tokens = \ req_state.num_computed_tokens + shift_computed_tokens + for mm_feature in req_state.mm_features: pos_info = mm_feature.mm_position start_pos = pos_info.offset @@ -1670,6 +1683,10 @@ def _gather_mm_embeddings( if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] + req_start_pos = req_start_idx + start_pos - num_computed_tokens + is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ + = True if is_embed is None else is_embed + mm_embeds_item = gather_mm_placeholders( encoder_output[start_idx:end_idx], is_embed=is_embed, @@ -1677,6 +1694,7 @@ def _gather_mm_embeddings( mm_embeds_req.append(mm_embeds_item) if self.is_multimodal_pruning_enabled and self.uses_mrope: + assert req_state.mrope_positions is not None should_sync_mrope_positions = True mm_embeds_req, new_mrope_positions, new_delta = ( self.model.recompute_mrope_positions( @@ -1685,18 +1703,19 @@ def _gather_mm_embeddings( mrope_positions=req_state.mrope_positions, num_computed_tokens=req_state.num_computed_tokens, )) - assert req_state.mrope_positions is not None req_state.mrope_positions.copy_(new_mrope_positions) req_state.mrope_position_delta = new_delta mm_embeds.extend(mm_embeds_req) + req_start_idx += num_scheduled_tokens + + is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens) if should_sync_mrope_positions: 
self._calc_mrope_positions(scheduler_output) - self.mrope_positions.copy_to_gpu( - scheduler_output.total_num_scheduled_tokens) + self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens) - return mm_embeds + return mm_embeds, is_mm_embed def _extract_encoder_inputs( self, @@ -1990,14 +2009,16 @@ def _preprocess( and not self.model_config.is_encoder_decoder): # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings( + scheduler_output) # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. inputs_embeds_scheduled = self.model.get_input_embeddings( - input_ids=self.input_ids.gpu[:num_scheduled_tokens], - multimodal_embeddings=mm_embeds or None, + self.input_ids.gpu[:num_scheduled_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, ) # TODO(woosuk): Avoid the copy. Optimize. @@ -2586,10 +2607,14 @@ def propose_draft_token_ids( [h[token_indices] for h in aux_hidden_states], dim=-1) else: target_hidden_states = hidden_states[token_indices] - mm_embeds = None + if self.supports_mm_inputs: - mm_embeds = self._gather_mm_embeddings(scheduler_output, - shift_computed_tokens=1) + mm_embed_inputs = self._gather_mm_embeddings( + scheduler_output, + shift_computed_tokens=1, + ) + else: + mm_embed_inputs = None draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, @@ -2599,8 +2624,9 @@ def propose_draft_token_ids( last_token_indices=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, - mm_embeds=mm_embeds, + mm_embed_inputs=mm_embed_inputs, ) + return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index a330f50875a8..2405f978ca73 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -263,6 +263,13 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() + # Only relevant for multimodal models + if self.supports_mm_inputs: + self.is_mm_embed_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.bool, + device="cpu", + pin_memory=self.pin_memory) + # Range tensor with values [0 .. self.max_num_tokens - 1]. 
# Used to initialize positions / context_lens / seq_lens # Keep in int64 to avoid overflow with long context @@ -879,13 +886,22 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", - ) -> list[torch.Tensor]: - mm_embeds: list[torch.Tensor] = [] + ) -> tuple[list[torch.Tensor], torch.Tensor]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + padded_total_num_scheduled_tokens = _get_padded_token_len( + self.num_tokens_paddings, total_num_scheduled_tokens) + + is_mm_embed = self.is_mm_embed_cpu + is_mm_embed[:padded_total_num_scheduled_tokens] = False + mm_embeds = list[torch.Tensor]() + req_start_idx = 0 + for req_id in self.input_batch.req_ids: num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens + # TODO unroll loop and assume/enforce --disable_chunked_mm_input # NOTE (NickLucche) here we diverge from logic in other runners, as # we assume to only have whole mm items to process. Hence we avoid @@ -906,26 +922,53 @@ def _gather_mm_embeddings( # The encoder output is already processed and stored # in the decoder's KV cache. continue + + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens, + ) + assert start_idx < end_idx + mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) assert encoder_output is not None,\ f"Encoder cache miss for {mm_hash}." + assert pos_info.is_embed is None, "Expected all positions to"\ " be contiguous and embeddings." - encoder_output = self.encoder_cache[mm_hash] + + req_start_pos = req_start_idx + start_pos - num_computed_tokens + is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ + = True + + # Only whole mm items are processed mm_embeds.append(encoder_output) - return mm_embeds - def _get_model_inputs(self, input_ids: torch.Tensor, - mm_embeds: list[torch.Tensor]): + req_start_idx += num_scheduled_tokens + + is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens] \ + .to(self.device) + + return mm_embeds, is_mm_embed + + def _get_model_inputs( + self, + input_ids: torch.Tensor, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]], + ): if self.supports_mm_inputs: + mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) + # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. inputs_embeds = self.model.get_input_embeddings( - input_ids=input_ids, + input_ids, multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, ) + return None, inputs_embeds else: # For text-only models, we use token ids as input. @@ -953,9 +996,10 @@ def execute_model( if self.supports_mm_inputs: # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) + mm_embed_inputs = self._gather_mm_embeddings(scheduler_output) else: - mm_embeds = [] + mm_embed_inputs = None + torch_xla.sync(wait=False) # Prepare inputs, the requests might be split into multiple # executions, combine the result of each execution. 
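Editorial note, not part of the patch: the per-request mask arithmetic introduced in `_gather_mm_embeddings` above can be sanity-checked with a small self-contained sketch. The variable names mirror the runner code; the numeric values are invented for illustration.

import torch

# Toy check of the is_mm_embed bookkeeping. One request: its image
# placeholder occupies prompt positions 4..8 (start_pos=4,
# num_encoder_tokens=5); 2 prompt tokens are already computed and 6 tokens
# are scheduled this step; it is the first request (req_start_idx=0).
start_pos, num_encoder_tokens = 4, 5
num_computed_tokens, num_scheduled_tokens = 2, 6
req_start_idx = 0

start_idx = max(num_computed_tokens - start_pos, 0)                  # 0
end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
              num_encoder_tokens)                                    # 4
req_start_pos = req_start_idx + start_pos - num_computed_tokens      # 2

is_mm_embed = torch.zeros(num_scheduled_tokens, dtype=torch.bool)
is_mm_embed[req_start_pos + start_idx:req_start_pos + end_idx] = True
print(is_mm_embed)  # tensor([False, False,  True,  True,  True,  True])

That is, within the 6-token window scheduled this step, positions 2..5 correspond to placeholder tokens whose embeddings will be overwritten by the gathered multimodal embeddings.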
@@ -972,7 +1016,7 @@ def execute_model( attn_metadata, logits_indices, padded_num_reqs, num_reqs,\ end_index = self._prepare_inputs(scheduler_output, start_index) input_ids, inputs_embeds = self._get_model_inputs( - self.input_ids, mm_embeds) + self.input_ids, mm_embed_inputs) torch_xla.sync(wait=False) # Run the decoder with set_forward_context( @@ -1325,9 +1369,15 @@ def _precompile_mm_encoder(self) -> None: hf_config.image_token_index placeholders_ids = placeholders_ids.to(self.device) + + mm_mask = torch.tensor([False] * num_tokens) + mm_mask[:items_size] = True + mm_mask = mm_mask.to(self.device) # Assign outputs or the graph will be cut short. - a, b = self._get_model_inputs(placeholders_ids, - [mm_embeds]) + a, b = self._get_model_inputs( + placeholders_ids, + mm_embed_inputs=([mm_embeds], mm_mask), + ) assert a is None torch_xla.sync(wait=False) @@ -1338,7 +1388,10 @@ def _precompile_mm_encoder(self) -> None: dtype=torch.int32, device="cpu") placeholders_ids = placeholders_ids.to(self.device) - a, b = self._get_model_inputs(placeholders_ids, []) + a, b = self._get_model_inputs( + placeholders_ids, + mm_embed_inputs=None, + ) assert a is None torch_xla.sync(wait=False)
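Editorial closing note, not part of the patch: the contract this series establishes is that callers build a boolean `is_multimodal` mask (e.g. `input_ids == config.image_token_id`, or `isin_list(...)` for several placeholder IDs) and the merge overwrites the masked rows of the text embeddings in place. The minimal, self-contained torch sketch below illustrates that contract; the vocabulary size, hidden size and `image_token_id` are invented toy values, not vLLM defaults.

import torch

# Minimal sketch of the new get_input_embeddings contract (toy sizes).
vocab_size, hidden_size, image_token_id = 8, 4, 7
embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)

input_ids = torch.tensor([1, 7, 7, 2])      # two image placeholder tokens
mm_embeds = torch.randn(2, hidden_size)     # one embedding per placeholder

with torch.no_grad():
    # Caller side: the forward pass now derives the mask itself,
    # mirroring `is_multimodal=input_ids == self.config.image_token_id`.
    is_multimodal = input_ids == image_token_id

    # Callee side: embed the text tokens, then scatter the multimodal
    # embeddings into the masked rows in place, in the spirit of
    # _merge_multimodal_embeddings.
    inputs_embeds = embed_tokens(input_ids)
    inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
                                  mm_embeds.to(dtype=inputs_embeds.dtype))

assert torch.equal(inputs_embeds[1], mm_embeds[0])
assert torch.equal(inputs_embeds[2], mm_embeds[1])

This is also why the runners now gather `is_mm_embed` alongside `mm_embeds`: once the mask travels with the embeddings, the merge no longer needs to know any placeholder token IDs.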