     AutoWeightsLoader,
     WeightsMapper,
     maybe_prefix,
-    merge_multimodal_embeddings,
+    _merge_multimodal_embeddings,
 )
 from .vision import get_vit_attn_backend
 
@@ -1123,46 +1123,60 @@ def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        *,
+        is_multimodal: Optional[torch.Tensor] = None,
+        handle_oov_mm_token: bool = False,
     ) -> torch.Tensor:
-        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        inputs_embeds = self._get_text_embeddings(
+            input_ids,
+            self.language_model.get_input_embeddings,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+
+        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+            return inputs_embeds
+
+        if is_multimodal is None:
+            raise ValueError(
+                "`get_input_embeddings` now requires `is_multimodal` arg, "
+                "please update your model runner according to "
+                "https://github.com/vllm-project/vllm/pull/16229.")
+
         deepstack_input_embeds = None
-        if multimodal_embeddings is not None and len(multimodal_embeddings) != 0:
-            # TODO (ywang96): support overlapping modalitiy embeddings so that
-            # `use_audio_in_video` will work on V1.
-            # split the feat dim to obtain multi-scale visual feature
-            if self.visual.deepstack_visual_indexes is not None:
-                multiscale_len = len(self.visual.deepstack_visual_indexes)
-                multimodal_embeddings_multiscale = []
-                for index, embeddings in enumerate(multimodal_embeddings):
-                    if embeddings.shape[-1] != self.config.text_config.hidden_size:
-                        visual_dim = embeddings.shape[-1] // (multiscale_len + 1)
-                        main_dim, multi_dim = visual_dim, visual_dim * multiscale_len
-                        embeddings_main, embeddings_multiscale = torch.split(
-                            embeddings, [main_dim, multi_dim], dim=-1
-                        )
-                        multimodal_embeddings[index] = embeddings_main
-                        multimodal_embeddings_multiscale.append(embeddings_multiscale)
-                if len(multimodal_embeddings_multiscale) > 0:
-                    deepstack_input_embeds = inputs_embeds.new_zeros(inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1))
-                    deepstack_input_embeds = merge_multimodal_embeddings(
-                        input_ids,
-                        deepstack_input_embeds,
-                        multimodal_embeddings_multiscale,
-                        placeholder_token_id=[self.config.image_token_id, self.config.video_token_id],
+        # TODO (ywang96): support overlapping modalitiy embeddings so that
+        # `use_audio_in_video` will work on V1.
+        # split the feat dim to obtain multi-scale visual feature
+        if self.visual.deepstack_visual_indexes is not None:
+            multiscale_len = len(self.visual.deepstack_visual_indexes)
+            multimodal_embeddings_multiscale = []
+            for index, embeddings in enumerate(multimodal_embeddings):
+                if embeddings.shape[-1] != self.config.text_config.hidden_size:
+                    visual_dim = embeddings.shape[-1] // (multiscale_len + 1)
+                    main_dim, multi_dim = visual_dim, visual_dim * multiscale_len
+                    embeddings_main, embeddings_multiscale = torch.split(
+                        embeddings, [main_dim, multi_dim], dim=-1
                     )
-                    deepstack_input_embeds = deepstack_input_embeds.view(inputs_embeds.shape[0], multiscale_len, visual_dim).permute(1,0,2).contiguous()
-                    self._set_deepstack_input_embeds(deepstack_input_embeds)
-
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                multimodal_embeddings,
-                [
-                    self.config.image_token_id,
-                    self.config.video_token_id,
-                    self.config.audio_token_id,
-                ],
-            )
+                    multimodal_embeddings[index] = embeddings_main
+                    multimodal_embeddings_multiscale.append(embeddings_multiscale)
+
+            # NOTE: This branch should only be triggered for image/video,
+            # but not audio-only inputs
+            if len(multimodal_embeddings_multiscale) > 0:
+                deepstack_input_embeds = inputs_embeds.new_zeros(inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1))
+                deepstack_input_embeds = _merge_multimodal_embeddings(
+                    inputs_embeds=deepstack_input_embeds,
+                    multimodal_embeddings=multimodal_embeddings_multiscale,
+                    is_multimodal=is_multimodal,
+                )
+                deepstack_input_embeds = deepstack_input_embeds.view(inputs_embeds.shape[0], multiscale_len, visual_dim).permute(1,0,2).contiguous()
+                self._set_deepstack_input_embeds(deepstack_input_embeds)
+
+        inputs_embeds = _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            multimodal_embeddings=multimodal_embeddings,
+            is_multimodal=is_multimodal,
+        )
 
         return inputs_embeds
 
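On the deepstack path, each visual embedding arrives with its multi-scale features packed along the feature dimension, width `visual_dim * (multiscale_len + 1)`: the first `visual_dim` slice is merged into `inputs_embeds`, and the remaining slices are scattered into a separate buffer that is reshaped to `[multiscale_len, num_tokens, visual_dim]` before `_set_deepstack_input_embeds`. A shape-only sketch of that split, with made-up sizes:

```python
import torch

# Made-up sizes for illustration only.
num_tokens, visual_dim, multiscale_len = 4, 16, 3

# Main + multi-scale features packed along the last (feature) dimension.
embeddings = torch.randn(num_tokens, visual_dim * (multiscale_len + 1))

main, multiscale = torch.split(
    embeddings, [visual_dim, visual_dim * multiscale_len], dim=-1
)
assert main.shape == (num_tokens, visual_dim)

# Mirrors `.view(num_tokens, multiscale_len, visual_dim).permute(1, 0, 2)`;
# `reshape` is used here because `torch.split` returns non-contiguous views,
# whereas the real code calls `.view` on a freshly allocated buffer.
deepstack = multiscale.reshape(num_tokens, multiscale_len, visual_dim)
deepstack = deepstack.permute(1, 0, 2).contiguous()
assert deepstack.shape == (multiscale_len, num_tokens, visual_dim)
```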