Commit f0d057a (parent: 81fd24b)

Update w.r.t. vllm-project#16229

Signed-off-by: DarkLight1337 <[email protected]>

File tree (1 file changed)

vllm/model_executor/models/qwen3_omni_moe_thinker.py

Lines changed: 52 additions & 38 deletions
@@ -91,7 +91,7 @@
     AutoWeightsLoader,
     WeightsMapper,
     maybe_prefix,
-    merge_multimodal_embeddings,
+    _merge_multimodal_embeddings,
 )
 from .vision import get_vit_attn_backend
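
For context, this rename swaps a placeholder-id-based helper for a mask-based one: the old merge_multimodal_embeddings located target rows by matching token ids in input_ids, while _merge_multimodal_embeddings takes an explicit boolean is_multimodal mask. Below is a minimal sketch of mask-based merging using a hypothetical standalone function; the real helper's internals are not shown in this diff, only the keywords visible in the hunks further down:

    import torch

    def merge_by_mask_sketch(
        inputs_embeds: torch.Tensor,                # (num_tokens, hidden_size)
        multimodal_embeddings: list[torch.Tensor],  # one tensor per mm item
        is_multimodal: torch.Tensor,                # (num_tokens,) bool mask
    ) -> torch.Tensor:
        # Concatenate all multimodal features and scatter them, in order,
        # into the token positions flagged True; no placeholder-id lookup.
        flat = torch.cat([e.reshape(-1, e.shape[-1]) for e in multimodal_embeddings])
        assert int(is_multimodal.sum()) == flat.shape[0], "mask/feature length mismatch"
        inputs_embeds[is_multimodal] = flat.to(dtype=inputs_embeds.dtype)
        return inputs_embeds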

@@ -1123,46 +1123,60 @@ def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        *,
+        is_multimodal: Optional[torch.Tensor] = None,
+        handle_oov_mm_token: bool = False,
     ) -> torch.Tensor:
-        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        inputs_embeds = self._get_text_embeddings(
+            input_ids,
+            self.language_model.get_input_embeddings,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+
+        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+            return inputs_embeds
+
+        if is_multimodal is None:
+            raise ValueError(
+                "`get_input_embeddings` now requires `is_multimodal` arg, "
+                "please update your model runner according to "
+                "https://github.com/vllm-project/vllm/pull/16229.")
+
         deepstack_input_embeds = None
-        if multimodal_embeddings is not None and len(multimodal_embeddings) != 0:
-            # TODO (ywang96): support overlapping modality embeddings so that
-            # `use_audio_in_video` will work on V1.
-            # split the feat dim to obtain multi-scale visual feature
-            if self.visual.deepstack_visual_indexes is not None:
-                multiscale_len = len(self.visual.deepstack_visual_indexes)
-                multimodal_embeddings_multiscale = []
-                for index, embeddings in enumerate(multimodal_embeddings):
-                    if embeddings.shape[-1] != self.config.text_config.hidden_size:
-                        visual_dim = embeddings.shape[-1] // (multiscale_len + 1)
-                        main_dim, multi_dim = visual_dim, visual_dim * multiscale_len
-                        embeddings_main, embeddings_multiscale = torch.split(
-                            embeddings, [main_dim, multi_dim], dim=-1
-                        )
-                        multimodal_embeddings[index] = embeddings_main
-                        multimodal_embeddings_multiscale.append(embeddings_multiscale)
-                if len(multimodal_embeddings_multiscale) > 0:
-                    deepstack_input_embeds = inputs_embeds.new_zeros(inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1))
-                    deepstack_input_embeds = merge_multimodal_embeddings(
-                        input_ids,
-                        deepstack_input_embeds,
-                        multimodal_embeddings_multiscale,
-                        placeholder_token_id=[self.config.image_token_id, self.config.video_token_id],
+        # TODO (ywang96): support overlapping modality embeddings so that
+        # `use_audio_in_video` will work on V1.
+        # split the feat dim to obtain multi-scale visual feature
+        if self.visual.deepstack_visual_indexes is not None:
+            multiscale_len = len(self.visual.deepstack_visual_indexes)
+            multimodal_embeddings_multiscale = []
+            for index, embeddings in enumerate(multimodal_embeddings):
+                if embeddings.shape[-1] != self.config.text_config.hidden_size:
+                    visual_dim = embeddings.shape[-1] // (multiscale_len + 1)
+                    main_dim, multi_dim = visual_dim, visual_dim * multiscale_len
+                    embeddings_main, embeddings_multiscale = torch.split(
+                        embeddings, [main_dim, multi_dim], dim=-1
                     )
-                    deepstack_input_embeds = deepstack_input_embeds.view(inputs_embeds.shape[0], multiscale_len, visual_dim).permute(1, 0, 2).contiguous()
-                    self._set_deepstack_input_embeds(deepstack_input_embeds)
-
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                multimodal_embeddings,
-                [
-                    self.config.image_token_id,
-                    self.config.video_token_id,
-                    self.config.audio_token_id,
-                ],
-            )
+                    multimodal_embeddings[index] = embeddings_main
+                    multimodal_embeddings_multiscale.append(embeddings_multiscale)
+
+            # NOTE: This branch should only be triggered for image/video,
+            # but not audio-only inputs
+            if len(multimodal_embeddings_multiscale) > 0:
+                deepstack_input_embeds = inputs_embeds.new_zeros(inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1))
+                deepstack_input_embeds = _merge_multimodal_embeddings(
+                    inputs_embeds=deepstack_input_embeds,
+                    multimodal_embeddings=multimodal_embeddings_multiscale,
+                    is_multimodal=is_multimodal,
+                )
+                deepstack_input_embeds = deepstack_input_embeds.view(inputs_embeds.shape[0], multiscale_len, visual_dim).permute(1, 0, 2).contiguous()
+                self._set_deepstack_input_embeds(deepstack_input_embeds)
+
+        inputs_embeds = _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            multimodal_embeddings=multimodal_embeddings,
+            is_multimodal=is_multimodal,
+        )
 
         return inputs_embeds
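
The new handle_oov_mm_token flag plausibly guards the text-embedding lookup against placeholder ids that fall outside the text vocabulary. This is a hypothetical sketch of what a _get_text_embeddings-style helper could do, not vLLM's actual implementation:

    from typing import Callable, Optional

    import torch

    def get_text_embeddings_sketch(
        input_ids: torch.Tensor,
        get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        if handle_oov_mm_token and is_multimodal is not None:
            # Zero out placeholder ids so the lookup never indexes past the
            # vocabulary; these rows are overwritten by the merge step anyway.
            safe_ids = input_ids.masked_fill(is_multimodal, 0)
            return get_input_embeddings(safe_ids)
        return get_input_embeddings(input_ids)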

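The ValueError points model-runner authors at vllm-project/vllm#16229. Here is a caller-side sketch of building the required is_multimodal mask; the placeholder token ids below are invented for illustration and would normally come from the model config:

    import torch

    IMAGE_TOKEN_ID = 151655  # illustrative only
    VIDEO_TOKEN_ID = 151656  # illustrative only
    AUDIO_TOKEN_ID = 151646  # illustrative only

    input_ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 42, AUDIO_TOKEN_ID])

    # One boolean per token: True where a multimodal embedding must be merged in.
    mm_token_ids = torch.tensor([IMAGE_TOKEN_ID, VIDEO_TOKEN_ID, AUDIO_TOKEN_ID])
    is_multimodal = torch.isin(input_ids, mm_token_ids)
    # tensor([False,  True,  True, False,  True])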
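
The multi-scale split itself is plain feature-dim arithmetic: each visual embedding arrives with (multiscale_len + 1) * visual_dim channels, and torch.split peels the main visual_dim slice off the deepstack slices. A small runnable example with invented shapes:

    import torch

    multiscale_len = 3  # e.g. three deepstack visual indexes (illustrative)
    visual_dim = 4      # per-scale feature width (illustrative)
    num_patches = 2

    # Fused embedding: main features plus multiscale_len extra scales.
    embeddings = torch.randn(num_patches, (multiscale_len + 1) * visual_dim)

    main_dim, multi_dim = visual_dim, visual_dim * multiscale_len
    embeddings_main, embeddings_multiscale = torch.split(
        embeddings, [main_dim, multi_dim], dim=-1
    )
    print(embeddings_main.shape)        # torch.Size([2, 4])
    print(embeddings_multiscale.shape)  # torch.Size([2, 12])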