diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index c8b6c6c86120..ec20375411e4 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -612,6 +612,7 @@ Specified using `--task generate`. | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ | ✅︎ | | `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ | ✅︎ | | `QwenVLForConditionalGeneration`^ | Qwen-VL | T + IE+ | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A+ | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ | diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 8014cb53f16a..01d6a188be99 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -190,6 +190,37 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData: ) +def run_phi4_multimodal(question: str, audio_count: int) -> ModelRequestData: + """ + Phi-4-multimodal-instruct supports both image and audio inputs. Here, we + show how to process audio inputs. + """ + model_path = snapshot_download( + "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70" + ) + # Since the vision-lora and speech-lora co-exist with the base model, + # we have to manually specify the path of the lora weights. 
+ speech_lora_path = os.path.join(model_path, "speech-lora") + placeholders = "<|audio|>" * audio_count + + prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>" + + engine_args = EngineArgs( + model=model_path, + max_model_len=12800, + max_num_seqs=2, + enable_lora=True, + max_lora_rank=320, + limit_mm_per_prompt={"audio": audio_count}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompts, + lora_requests=[LoRARequest("speech", 1, speech_lora_path)], + ) + + # Qwen2-Audio def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData: model_name = "Qwen/Qwen2-Audio-7B-Instruct" @@ -303,6 +334,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: "granite_speech": run_granite_speech, "minicpmo": run_minicpmo, "phi4_mm": run_phi4mm, + "phi4_multimodal": run_phi4_multimodal, "qwen2_audio": run_qwen2_audio, "qwen2_5_omni": run_qwen2_5_omni, "ultravox": run_ultravox, diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index e4811c023377..bbe7541e0182 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -987,6 +987,41 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData: ) +# HF format Phi-4-multimodal-instruct +def run_phi4_multimodal(questions: list[str], modality: str) -> ModelRequestData: + """ + Phi-4-multimodal-instruct supports both image and audio inputs. Here, we + show how to process image inputs. + """ + assert modality == "image" + model_path = snapshot_download( + "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70" + ) + # Since the vision-lora and speech-lora co-exist with the base model, + # we have to manually specify the path of the lora weights. + vision_lora_path = os.path.join(model_path, "vision-lora") + prompts = [ + f"<|user|><|image|>{question}<|end|><|assistant|>" for question in questions + ] + engine_args = EngineArgs( + model=model_path, + max_model_len=5120, + max_num_seqs=2, + max_num_batched_tokens=12800, + enable_lora=True, + max_lora_rank=320, + # Note - mm_processor_kwargs can also be passed to generate/chat calls + mm_processor_kwargs={"dynamic_hd": 16}, + limit_mm_per_prompt={"image": 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + lora_requests=[LoRARequest("vision", 1, vision_lora_path)], + ) + + # Pixtral HF-format def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1244,6 +1279,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: "paligemma2": run_paligemma2, "phi3_v": run_phi3v, "phi4_mm": run_phi4mm, + "phi4_multimodal": run_phi4_multimodal, "pixtral_hf": run_pixtral_hf, "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index eb4f3b6c8f44..385206b525fe 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -686,6 +686,40 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_phi4_multimodal(question: str, image_urls: list[str]) -> ModelRequestData: + """ + Phi-4-multimodal-instruct supports both image and audio inputs. Here, we + show how to process multi images inputs. 
+ """ + + model_path = snapshot_download( + "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70" + ) + # Since the vision-lora and speech-lora co-exist with the base model, + # we have to manually specify the path of the lora weights. + vision_lora_path = os.path.join(model_path, "vision-lora") + engine_args = EngineArgs( + model=model_path, + max_model_len=4096, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + enable_lora=True, + max_lora_rank=320, + # Note - mm_processor_kwargs can also be passed to generate/chat calls + mm_processor_kwargs={"dynamic_hd": 4}, + ) + + placeholders = "<|image|>" * len(image_urls) + prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>" + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + lora_requests=[LoRARequest("vision", 1, vision_lora_path)], + ) + + def load_qwen_vl_chat(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "Qwen/Qwen-VL-Chat" engine_args = EngineArgs( @@ -912,6 +946,7 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData: "ovis": load_ovis, "phi3_v": load_phi3v, "phi4_mm": load_phi4mm, + "phi4_multimodal": load_phi4_multimodal, "pixtral_hf": load_pixtral_hf, "qwen_vl_chat": load_qwen_vl_chat, "qwen2_vl": load_qwen2_vl, diff --git a/tests/models/multimodal/generation/test_phi4_multimodal.py b/tests/models/multimodal/generation/test_phi4_multimodal.py new file mode 100644 index 000000000000..db8984d8656f --- /dev/null +++ b/tests/models/multimodal/generation/test_phi4_multimodal.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from collections.abc import Sequence +from typing import Optional + +import librosa +import pytest +from huggingface_hub import snapshot_download + +from vllm.assets.image import ImageAsset +from vllm.lora.request import LoRARequest +from vllm.multimodal.image import rescale_image_size +from vllm.platforms import current_platform + +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput, + PromptImageInput, VllmRunner) +from ....utils import large_gpu_test +from ...utils import check_logprobs_close + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "<|user|>\n<|image|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 + "cherry_blossom": + "<|user|>\n<|image|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501 +}) +HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image|>\n<|image|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 + +model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct", + revision="refs/pr/70") +# Since the vision-lora and speech-lora co-exist with the base model, +# we have to manually specify the path of the lora weights. 
+vision_lora_path = os.path.join(model_path, "vision-lora") +speech_question = os.path.join(model_path, "examples", + "what_is_shown_in_this_image.wav") +models = [model_path] + +target_dtype = "half" + +# ROCm Triton FA can run into shared memory issues with these models, +# use other backends in the meantime +# FIXME (mattwong, gshtrasb, hongxiayan) +if current_platform.is_rocm(): + os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" + + +def run_test( + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + inputs: Sequence[tuple[list[str], PromptImageInput, + Optional[PromptAudioInput]]], + model: str, + *, + max_model_len: int, + dtype: str, + max_tokens: int, + num_logprobs: int, + mm_limit: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test are from IMAGE_ASSETS. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + # max_model_len should be greater than image_feature_size + with vllm_runner( + model, + task="generate", + max_model_len=max_model_len, + max_num_seqs=2, + dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_lora=True, + max_lora_rank=320, + gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI + enforce_eager=True, + trust_remote_code=False, + ) as vllm_model: + lora_request = LoRARequest("vision", 1, vision_lora_path) + vllm_outputs_per_case = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + audios=audios, + lora_request=lora_request) + for prompts, images, audios in inputs + ] + + with hf_runner(model, dtype=dtype) as hf_model: + hf_model.model.load_adapter( + vision_lora_path, + adapter_name="vision", + ) + hf_processor = hf_model.processor + eos_token_id = hf_processor.tokenizer.eos_token_id + hf_outputs_per_case = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + audios=audios, + eos_token_id=eos_token_id) + for prompts, images, audios in inputs + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, + vllm_outputs_per_case): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_model_len", [12800]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, + dtype: str, max_model_len: int, max_tokens: int, + num_logprobs: int) -> None: + images = [asset.pil_image for 
asset in image_assets] + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + None, + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + + run_test( + hf_runner, + vllm_runner, + inputs_per_image, + model, + dtype=dtype, + max_model_len=max_model_len, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + ) + + +@large_gpu_test(min_gb=48) +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + # [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_model_len", [25600]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, + size_factors, dtype: str, max_model_len: int, + max_tokens: int, num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_case = [ + ( + [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], + [[rescale_image_size(image, factor) for image in images] + for factor in size_factors], + None, + ), + ] + + run_test( + hf_runner, + vllm_runner, + inputs_per_case, + model, + dtype=dtype, + max_model_len=max_model_len, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=2, + tensor_parallel_size=1, + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_model_len", [12800]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str, + max_model_len: int, max_tokens: int, + num_logprobs: int) -> None: + + # use the example speech question so that the model outputs are reasonable + audio = librosa.load(speech_question, sr=16000) + image = ImageAsset("cherry_blossom").pil_image.convert("RGB") + + inputs_vision_speech = [ + ( + ["<|user|><|image|><|audio|><|end|><|assistant|>"], + [image], + [audio], + ), + ] + + run_test( + hf_runner, + vllm_runner, + inputs_vision_speech, + model, + dtype=dtype, + max_model_len=max_model_len, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + ) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index fd5842523178..627c6715fbc7 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -41,12 +41,18 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: def _test_processing_correctness( - model_id: str, + model_id_or_arch: str, hit_rate: float, num_batches: int, simplify_rate: float, ): - model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + if model_id_or_arch in HF_EXAMPLE_MODELS.get_supported_archs(): + # Use model architecture to get the default model id + model_info = HF_EXAMPLE_MODELS.get_hf_info(model_id_or_arch) + model_id = model_info.default + else: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id_or_arch) + model_id = model_id_or_arch model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") @@ -58,7 +64,7 @@ def _test_processing_correctness( trust_remote_code=model_info.trust_remote_code, seed=0, dtype="auto", - revision=None, + 
revision=model_info.revision, hf_overrides=model_info.hf_overrides, ) @@ -330,6 +336,28 @@ def test_processing_correctness( ) +# Phi4MultimodalForCausalLM share same model repo with original format +# Phi4MMForCausalLM, so we add it as a separate test case +# Remove this test after conversion PR merged: +# https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/70 +@pytest.mark.parametrize("model_arch", ["Phi4MultimodalForCausalLM"]) +@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) +@pytest.mark.parametrize("num_batches", [32]) +@pytest.mark.parametrize("simplify_rate", [1.0]) +def test_processing_correctness_phi4_multimodal( + model_arch: str, + hit_rate: float, + num_batches: int, + simplify_rate: float, +): + _test_processing_correctness( + model_arch, + hit_rate=hit_rate, + num_batches=num_batches, + simplify_rate=simplify_rate, + ) + + def _assert_inputs_equal( a: MultiModalInputs, b: MultiModalInputs, diff --git a/tests/models/registry.py b/tests/models/registry.py index 84ca0bc60003..574f0a3ff20f 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -425,6 +425,8 @@ def check_available_online( "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501 "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", trust_remote_code=True), + "Phi4MultimodalForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", # noqa: E501 + revision="refs/pr/70"), "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501 tokenizer_mode="mistral"), "QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL", diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py new file mode 100644 index 000000000000..432b707a6159 --- /dev/null +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -0,0 +1,1455 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Literal, Optional, TypedDict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import (BatchFeature, Phi4MultimodalAudioConfig, + Phi4MultimodalConfig, Phi4MultimodalFeatureExtractor, + Phi4MultimodalImageProcessorFast) +from transformers import Phi4MultimodalProcessor as Phi4MMProcessor +from transformers.models.phi4_multimodal.modeling_phi4_multimodal import ( + Phi4MultimodalAudioConvModule, Phi4MultimodalAudioNemoConvSubsampling, + Phi4MultimodalAudioRelativeAttentionBias, adaptive_enc_mask, unfold_tensor) + +from vllm.config import VllmConfig +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import (AudioProcessorItems, 
ImageEmbeddingItems, + ImageProcessorItems, ImageSize, + MultiModalDataItems, MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of + +from .idefics2_vision_model import Idefics2VisionTransformer +from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) + +# <|endoftext10|> (see vocab.json in hf model) +_IMAGE_PLACEHOLDER_TOKEN_ID = 200010 +# <|endoftext11|> +_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 + +_AUDIO_MAX_SOUNDFILE_SIZE = 241_000 + + +def _get_padding_size(orig_width: int, orig_height: int, target_height: int, + target_width: int): + ratio_width = target_width / orig_width + ratio_height = target_height / orig_height + + if ratio_width < ratio_height: + padding_width = 0 + padding_height = target_height - int(orig_height * ratio_width) + else: + padding_width = target_width - int(orig_width * ratio_height) + padding_height = 0 + return padding_height, padding_width + + +class Phi4MMProjector(nn.Module): + + def __init__(self, input_size: int, hidden_size: int): + super().__init__() + self.up = ColumnParallelLinear(input_size, hidden_size) + self.down = RowParallelLinear(hidden_size, hidden_size) + self.act = get_act_fn("gelu") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.up(x) + x = self.act(x) + x, _ = self.down(x) + return x + + +class Phi4MMImageEmbedding(nn.Module): + """Image embedding.""" + + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.layer_idx = config.vision_config.feature_layer + self.crop_size = config.vision_config.crop_size + self.image_dim_out = config.vision_config.hidden_size + + n_patches = (config.vision_config.image_size // + config.vision_config.patch_size) + if n_patches % 2 != 0: + self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) + n_patches += 1 + self.num_img_tokens = (n_patches // 2)**2 + + num_hidden_layers = (config.vision_config.num_hidden_layers + + self.layer_idx + + 1 if self.layer_idx < 0 else self.layer_idx + 1) + self.img_processor = Idefics2VisionTransformer( + config.vision_config, + require_post_norm=False, + num_hidden_layers_override=num_hidden_layers) + self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) + self.img_projection = Phi4MMProjector(self.image_dim_out, + config.hidden_size) + self.global_img_feature_extensor = nn.Parameter( + torch.zeros([1, 1, self.image_dim_out])) + self.sub_img_feature_extensor = nn.Parameter( + torch.zeros([1, 1, 1, self.image_dim_out])) + + def get_img_features( + self, + img_embeds: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + img_feature = self.img_processor(img_embeds, + patch_attention_mask=attention_mask) + + patch_feature = img_feature + # reshape to 2D tensor + width = int(math.sqrt(patch_feature.size(1))) + patch_feature = patch_feature.view(-1, width, width, + patch_feature.size(-1)) + # convert to NCHW + patch_feature = patch_feature.permute(0, 3, 1, 2) + if getattr(self, "img_processor_padding", None) is not None: + patch_feature = self.img_processor_padding(patch_feature) + patch_feature = self.image_token_compression(patch_feature) + # convert to 
NHWC + patch_feature = patch_feature.permute(0, 2, 3, 1) + patch_feature = patch_feature.view( + -1, + patch_feature.size(1) * patch_feature.size(2), + patch_feature.size(-1)) + return patch_feature + + def forward( + self, + image_pixel_values: torch.FloatTensor, + image_sizes: Optional[torch.Tensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + image_pixel_values = image_pixel_values.to( + self.img_processor.embeddings.patch_embedding.weight.dtype) + + target_device = self.img_projection.up.bias.device + target_dtype = self.img_projection.up.bias.dtype + + batch_size = image_pixel_values.shape[0] + + img_features = self.get_img_features( + image_pixel_values.flatten(0, 1), + attention_mask=image_attention_mask.flatten(0, 1).to( + dtype=bool, device=target_device), + ) + base_feat_size = int(np.sqrt(img_features.shape[1])) + img_features = img_features.view(batch_size, -1, base_feat_size**2, + self.image_dim_out) + image_sizes = image_sizes.view(-1, 2) + + output_imgs = [] + for idx in range(batch_size): + height, width = image_sizes[idx] + height_ratio = height // self.crop_size + width_ratio = width // self.crop_size + area_ratio = height_ratio * width_ratio + + global_img = img_features[idx, :1] + global_img = global_img.reshape(1, base_feat_size, base_feat_size, + self.image_dim_out).contiguous() + temporary_extensor = self.sub_img_feature_extensor.repeat( + 1, base_feat_size, 1, 1) + global_img = torch.cat([global_img, temporary_extensor], + dim=2).reshape(1, -1, self.image_dim_out) + + sub_img = img_features[idx, 1:] + sub_img = sub_img[:area_ratio] + sub_img = (sub_img.reshape( + height_ratio, width_ratio, base_feat_size, base_feat_size, + self.image_dim_out).transpose(1, 2).reshape( + 1, height_ratio * base_feat_size, + width_ratio * base_feat_size, + self.image_dim_out).contiguous()) + + if image_attention_mask is not None: + reshaped_image_attention_mask = ( + image_attention_mask[idx, 1:area_ratio + 1, + 0::2, 0::2].reshape( + height_ratio, width_ratio, + base_feat_size, + base_feat_size).transpose( + 1, 2).reshape( + 1, height_ratio * + base_feat_size, + width_ratio * + base_feat_size)) + useful_height = int( + reshaped_image_attention_mask[0, :, 0].sum().item()) + useful_width = int( + reshaped_image_attention_mask[0, 0, :].sum().item()) + sub_img = sub_img[:, :useful_height, :useful_width] + temporary_extensor = self.sub_img_feature_extensor.repeat( + 1, useful_height, 1, 1) + else: + temporary_extensor = self.sub_img_feature_extensor.repeat( + 1, height_ratio * base_feat_size, 1, 1) + + sub_img = torch.cat([sub_img, temporary_extensor], + dim=2).reshape(1, -1, self.image_dim_out) + + # Merge global and sub + output_imgs.append( + torch.cat( + [sub_img, self.global_img_feature_extensor, global_img], + dim=1)) + + img_set_tensor = [] + for output_img in output_imgs: + output_img = output_img.to(device=target_device, + dtype=target_dtype) + img_feature_proj = self.img_projection(output_img) + img_set_tensor.append(img_feature_proj.flatten(0, 1)) + + return img_set_tensor + + +class Phi4MultimodalAudioMLP(nn.Module): + + def __init__( + self, + config: Phi4MultimodalAudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.act_fn = MulAndSilu() + self.gate_up_proj = MergedColumnParallelLinear( + config.hidden_size, [config.intermediate_size] * 2, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") 
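+        # gate_up_proj fuses the gate and up projections into a single
+        # column-parallel GEMM; MulAndSilu later splits the concatenated
+        # output back into its two halves and applies the SiLU gating.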
+ self.down_proj = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.layer_norm(hidden_states) + hidden_states, _ = self.gate_up_proj(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states, _ = self.down_proj(hidden_states) + return hidden_states + + +class Phi4MultimodalAudioAttention(nn.Module): + + def __init__( + self, + config: Phi4MultimodalAudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.total_num_heads + if self.head_dim * self.total_num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.num_heads = divide(self.total_num_heads, self.tp_size) + + def split_attn_mask(self, attention_mask: torch.Tensor) -> torch.Tensor: + start_idx = self.num_heads * self.tp_rank + end_idx = self.num_heads * (self.tp_rank + 1) + return attention_mask[:, start_idx:end_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + qkv_states, _ = self.qkv_proj(hidden_states) + query, key, value = qkv_states.chunk(3, dim=-1) + + bsz, seq_len, _ = query.size() + query = query.view(bsz, seq_len, self.num_heads, self.head_dim) + key = key.view(bsz, seq_len, self.num_heads, self.head_dim) + value = value.view(bsz, seq_len, self.num_heads, self.head_dim) + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) + + attention_mask = self.split_attn_mask(attention_mask) + out = F.scaled_dot_product_attention( + query, + key, + value, + scale=self.scale, + attn_mask=attention_mask, + ) + out = out.transpose(1, 2).reshape(bsz, seq_len, -1) + + attn_output, _ = self.o_proj(out) + + return attn_output + + +class Phi4MultimodalAudioConformerEncoderLayer(nn.Module): + + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + + self.feed_forward_in = Phi4MultimodalAudioMLP(config) + self.self_attn = Phi4MultimodalAudioAttention(config) + self.conv = Phi4MultimodalAudioConvModule(config) + self.feed_forward_out = Phi4MultimodalAudioMLP(config) + self.layer_norm_att = nn.LayerNorm(config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states) + hidden_states = self.layer_norm_att(residual) + + hidden_states = residual + self.self_attn(hidden_states, + attention_mask) + hidden_states = hidden_states + self.conv(hidden_states) + hidden_states = hidden_states + 0.5 * self.feed_forward_out( + hidden_states) + + out = 
self.layer_norm(hidden_states) + + return out + + +class Phi4MMAudioMeanVarianceNormLayer(nn.Module): + """Mean/variance normalization layer. + + Will subtract mean and multiply input by inverted standard deviation. + Typically used as a very first layer in a model. + + Args: + input_size: int + layer input size. + """ + + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.global_mean = nn.Parameter(torch.zeros(config.input_size)) + self.global_invstd = nn.Parameter(torch.ones(config.input_size)) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + """MeanVarianceNormLayer Forward + + Args: + input_: torch.Tensor + input tensor. + """ + return (input_ - self.global_mean) * self.global_invstd + + +class Phi4MultimodalAudioModel(nn.Module): + + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + + self.encoder_embedding = Phi4MMAudioMeanVarianceNormLayer(config) + self.embed = Phi4MultimodalAudioNemoConvSubsampling(config) + self.relative_attention_bias_layer = ( + Phi4MultimodalAudioRelativeAttentionBias(config)) + self.encoders = nn.ModuleList([ + Phi4MultimodalAudioConformerEncoderLayer(config) + for _ in range(config.num_blocks) + ]) + + def _streaming_mask( + self, + seq_len: int, + batch_size: int, + chunk_size: int, + left_chunk: int, + ): + # Create mask matrix for streaming + # S stores start index. if chunksize is 18, s is [0,18,36,....] + chunk_start_idx = np.arange(0, seq_len, chunk_size) + + enc_streaming_mask = (adaptive_enc_mask( + seq_len, chunk_start_idx, + left_window=left_chunk).unsqueeze(0).expand([batch_size, -1, -1])) + return enc_streaming_mask + + def forward_embeddings( + self, + hidden_states: torch.Tensor, + masks: torch.Tensor, + ): + """Forwarding the inputs through the top embedding layers""" + seq_len = math.ceil(hidden_states.shape[1] / + self.config.time_reduction) + if seq_len <= 0: + raise ValueError( + f"Sequence length after time reduction is invalid: {seq_len}." 
+ "Your input feature is too short.") + + batch_size = hidden_states.shape[0] + + enc_streaming_mask = self._streaming_mask(seq_len, batch_size, + self.config.chunk_size, + self.config.left_chunk) + enc_streaming_mask = enc_streaming_mask.to(hidden_states.device) + + hidden_states, masks = self.embed(hidden_states, masks) + + streaming_mask = enc_streaming_mask + if streaming_mask is not None and masks is not None: + hs_mask = masks & streaming_mask + elif masks is not None: + hs_mask = masks + else: + hs_mask = streaming_mask + + return hidden_states, hs_mask, masks + + def calculate_hs_mask(self, hidden_states: torch.Tensor, + device: torch.device, mask: torch.Tensor): + max_audio_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] + enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, + self.config.chunk_size, + self.config.left_chunk) + enc_streaming_mask = enc_streaming_mask.to(device) + if mask is None: + return enc_streaming_mask + + feature_lens = mask.sum(1) + padding_length = feature_lens + pad_mask = torch.arange(0, max_audio_length, device=device).expand( + padding_length.size(0), -1) < padding_length.unsqueeze(1) + pad_mask = pad_mask.unsqueeze(1) + pad_mask = pad_mask & enc_streaming_mask + return pad_mask + + def forward(self, + hidden_states: torch.Tensor, + mask: Optional[torch.Tensor] = None): + hidden_states = self.encoder_embedding(hidden_states) + hidden_states, hs_mask, mask = self.forward_embeddings( + hidden_states, mask) + + unfolded = False + bs, seq_len, _ = hidden_states.shape + max_seq_len = 500 # maximum position for absolute positional encoding + if seq_len > max_seq_len: + # audio sequence is longer than max_seq_len, + # unfold it into chunks of max_seq_len + unfolded = True + # the unfold op will drop residual frames, + # pad it to the multiple of max_seq_len + if seq_len % max_seq_len > 0: + chunk_pad_size = max_seq_len - (seq_len % max_seq_len) + else: + chunk_pad_size = 0 + if chunk_pad_size > 0: + hidden_states_pad = F.pad(hidden_states, + (0, 0, 0, chunk_pad_size), + "constant", 0) + hidden_states = hidden_states_pad.to(hidden_states.device) + + hidden_states = unfold_tensor(hidden_states, max_seq_len) + masks_unfold = None + if mask is not None: + # revise hs_mask here because the previous calculated hs_mask + # did not consider extra pad + subsampled_pad_mask = mask.squeeze( + 1) # [bz, subsampled_unmask_seq_len] + extra_padded_subsamlped_pad_mask = F.pad( + subsampled_pad_mask, (0, chunk_pad_size), "constant", + False) # extra padding to the pad mask + extra_padded_subsamlped_pad_mask = ( + extra_padded_subsamlped_pad_mask.unsqueeze(-1).float()) + masks_unfold = unfold_tensor( + extra_padded_subsamlped_pad_mask, max_seq_len + ) # unfold the pad mask like we did to the input tensor + masks_unfold = masks_unfold.squeeze( + -1).bool() # unfold op does not support bool tensor + hs_mask = self.calculate_hs_mask( + hidden_states, hidden_states.device, masks_unfold + ) # calculate hs_mask based on the unfolded pad mask + + relative_attention_bias = self.relative_attention_bias_layer( + hidden_states) + attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias + + for layer in self.encoders: + hidden_states = layer(hidden_states, attention_mask) + + if unfolded: + embed_dim = hidden_states.shape[-1] + hidden_states = hidden_states.reshape(bs, -1, embed_dim) + # if we ever padded before unfolding, we need to remove the padding + if chunk_pad_size > 0: + hidden_states = hidden_states[:, :-chunk_pad_size, :] + + 
return hidden_states + + +class Phi4MMAudioEmbedding(nn.Module): + + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.layer_idx = config.audio_config.feature_layer + + self.encoder = Phi4MultimodalAudioModel(config.audio_config) + + audio_config = config.audio_config + proj_input_size = (audio_config.hidden_size * + audio_config.downsample_rate) + self.vision_speech_projection = Phi4MMProjector( + proj_input_size, config.hidden_size) + self.speech_projection = Phi4MMProjector(proj_input_size, + config.hidden_size) + + def get_projection( + self, + audio_projection_mode: Literal["speech", "vision"], + ) -> Phi4MMProjector: + if audio_projection_mode == "speech": + return self.speech_projection + elif audio_projection_mode == "vision": + return self.vision_speech_projection + + def forward( + self, + audio_input_features: torch.FloatTensor, + audio_embed_sizes=None, + audio_attention_mask=None, + audio_projection_mode="speech", + ) -> torch.FloatTensor: + + audio_projection = self.get_projection(audio_projection_mode) + + target_device = audio_projection.up.bias.device + target_dtype = audio_projection.up.bias.dtype + + audio_input_features = audio_input_features.to(device=target_device, + dtype=target_dtype) + + audio_encoder_hidden_states = self.encoder(audio_input_features, + audio_attention_mask) + audio_embeds = audio_projection(audio_encoder_hidden_states) + + return audio_embeds.flatten(0, 1) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Phi4MMImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: Union[torch.Tensor, list[torch.Tensor]] + """ + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` + + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ + + image_sizes: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(height, width)` format. + """ + + num_img_tokens: list[int] + """Shape: `(batch_size * num_images)`""" + + image_attention_mask: torch.Tensor + """Shape: `(batch_size * num_images, H_mask, W_mask)`""" + + +class Phi4MMImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: Union[torch.Tensor, list[torch.Tensor]] + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. 
+ """ + + +class Phi4MMAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + data: Union[torch.Tensor, list[torch.Tensor]] + """Shape: `(batch_size * num_audios, 80, M)""" + + +class Phi4MMAudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + data: NestedTensors + """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + + +Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs] +Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] + + +def cat_with_pad(tensors, dim, padding_value=0): + """ + cat along dim, while pad to max for all other dims + """ + ndim = tensors[0].dim() + assert all( + t.dim() == ndim for t in + tensors[1:]), "All tensors must have the same number of dimensions" + + out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] + out_size[dim] = sum(t.shape[dim] for t in tensors) + output = tensors[0].new_full(out_size, padding_value) + + index = 0 + for t in tensors: + # Create a slice list where every dimension except dim is full slice + slices = [slice(0, t.shape[d]) for d in range(ndim)] + # Update only the concat dimension slice + slices[dim] = slice(index, index + t.shape[dim]) + + output[slices] = t + index += t.shape[dim] + + return output + + +class Phi4MMProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self) -> Phi4MultimodalConfig: + return self.ctx.get_hf_config(Phi4MultimodalConfig) + + def get_hf_processor( + self, + *, + dynamic_hd: Optional[int] = None, + **kwargs: object, + ) -> Phi4MMProcessor: + if dynamic_hd is not None: + kwargs["dynamic_hd"] = dynamic_hd + + return self.ctx.get_hf_processor(**kwargs) + + def get_feature_extractor(self) -> Phi4MultimodalFeatureExtractor: + return self.get_hf_processor().audio_processor + + def get_image_processor( + self, + processor: Optional[Phi4MMProcessor] = None, + ) -> Phi4MultimodalImageProcessorFast: + if processor is None: + processor = self.get_hf_processor() + return processor.image_processor + + def get_dynamic_hd( + self, + processor: Optional[Phi4MMProcessor] = None, + ) -> int: + return self.get_image_processor(processor).dynamic_hd + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None, "image": None} + + def _find_target_aspect_ratio( + self, + orig_width: int, + orig_height: int, + image_size: int, + max_num: int, + min_num: int, + ): + w_crop_num = math.ceil(orig_width / float(image_size)) + h_crop_num = math.ceil(orig_height / float(image_size)) + if w_crop_num * h_crop_num > max_num: + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set((i, j) for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + image_processor = self.get_image_processor() + target_aspect_ratio = image_processor.find_closest_aspect_ratio( + aspect_ratio, + target_ratios, + orig_width, + orig_height, + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + else: + target_width = image_size * w_crop_num + target_height = image_size * h_crop_num + target_aspect_ratio = (w_crop_num, h_crop_num) + return target_aspect_ratio, target_height, target_width + + def _compute_num_image_tokens( + self, + orig_width: int, + orig_height: int, + dynamic_hd_size: int, + 
vit_image_size: int, + vit_patch_size: int, + token_compression_factor: int = 2, + ): + """ + compute the number of tokens an image is expected to take up considering + the image encoder architecture and exclude output features containing + only padding pixels + + for siglip, vit_image_size=448, vit_patch_size=14, so output will be + 32x32 feature map + NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 + """ + assert vit_image_size % vit_patch_size == 0, ( + "vit_image_size must be divisible by vit_patch_size") + assert (vit_image_size // vit_patch_size % + token_compression_factor == 0), ( + "vit_image_size // vit_patch_size must be divisible by " + "token_compression_factor") + + target_aspect_ratio, target_height, target_width = ( + self._find_target_aspect_ratio(orig_width, + orig_height, + vit_image_size, + dynamic_hd_size, + min_num=1)) + assert target_aspect_ratio[0] * vit_image_size == target_width, ( + f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}") + assert target_aspect_ratio[1] * vit_image_size == target_height, ( + f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}") + assert (target_height % vit_image_size == 0 + and target_width % vit_image_size == 0) + + padding_height, padding_width = _get_padding_size( + orig_width, orig_height, target_height, target_width) + assert padding_width == 0 or padding_height == 0, \ + "padding_width or padding_height must be 0" + + target_feat_width = target_width // vit_patch_size + target_feat_height = target_height // vit_patch_size + if padding_width >= vit_patch_size: + assert padding_height == 0, "padding_height not 0" + non_pad_feat_width = target_feat_width - math.floor( + padding_width / vit_patch_size) + non_pad_feat_height = target_feat_height + elif padding_height >= vit_patch_size: + assert padding_width == 0, "padding_width not 0" + non_pad_feat_height = target_feat_height - math.floor( + padding_height / vit_patch_size) + non_pad_feat_width = target_feat_width + else: + # small padding shorter than a vit patch + non_pad_feat_width = target_feat_width + non_pad_feat_height = target_feat_height + + feat_width = non_pad_feat_width // token_compression_factor + feat_height = non_pad_feat_height // token_compression_factor + # NOTE it's possible that the non-padding feature is not divisible + if non_pad_feat_width % token_compression_factor != 0: + feat_width += 1 + if non_pad_feat_height % token_compression_factor != 0: + feat_height += 1 + num_hd_patch_tokens = feat_width * feat_height + num_hd_newline_tokens = feat_height + vit_feature_size = vit_image_size // vit_patch_size + num_global_image_tokens = (vit_feature_size // + token_compression_factor)**2 + num_sep_tokens = 1 + num_global_image_newline_tokens = \ + vit_feature_size // token_compression_factor + + return (num_global_image_tokens + num_sep_tokens + + num_hd_patch_tokens + num_hd_newline_tokens + + num_global_image_newline_tokens) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Phi4MMProcessor] = None, + ) -> int: + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + vit_image_size = vision_config.image_size + vit_patch_size = vision_config.patch_size + + dynamic_hd_size = self.get_dynamic_hd(processor=processor) + + # we use default `token_compression_factor=2`, + # since it's not in HF vision config. 
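+        # Worked example, assuming the SigLIP sizes noted in
+        # _compute_num_image_tokens (448 / 14): a 448x448 image is a single
+        # crop with no padding, the 32x32 ViT feature map compresses to
+        # 16x16, so the count is 256 global tokens + 1 separator + 256 HD
+        # patch tokens + 16 HD newlines + 16 global newlines = 545 tokens.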
+ image_num_tokens = self._compute_num_image_tokens( + image_width, + image_height, + dynamic_hd_size=dynamic_hd_size, + vit_image_size=vit_image_size, + vit_patch_size=vit_patch_size, + ) + + return image_num_tokens + + def get_image_size_with_most_features( + self, + processor: Optional[Phi4MMProcessor] = None, + ) -> ImageSize: + vit_image_size = self.get_hf_config().vision_config.image_size + + max_side = vit_image_size * self.get_dynamic_hd(processor=processor) + return ImageSize(height=max_side, width=vit_image_size) + + def get_audio_num_frames(self, audio_len: int, sr: float) -> int: + """ + Compute the output size of the `extract_features` method. + + Args: + audio_len (int): Length of the input waveform in samples. + sr (float): Sampling rate of the waveform, either 16000 or 8000. + + Returns: + tuple (int, int): Output size as (T, D), where: + T: Number of time frames. + D: Number of Mel filterbank bins (80). + """ + + # Resample to 16000 or 8000 if needed + if sr > 16000: + audio_len //= sr // 16000 + elif 8000 <= sr < 16000: + # We'll resample to 16K from 8K + audio_len *= 2 + elif sr < 8000: + raise RuntimeError(f"Unsupported sample rate {sr}") + + # Spectrogram parameters for 16 kHz + win_length = 400 # Frame length in samples + hop_length = 160 # Frame shift in samples + + # Calculate number of frames (T) + num_frames = (audio_len - win_length) // hop_length + 1 + if num_frames < 1: + raise ValueError("Waveform too short for given parameters.") + + # Return time frames (T) + return num_frames + + def _compute_audio_embed_size(self, audio_frames: int) -> int: + """ + Compute the size of audio embeddings from the number of audio frames. + """ + # `_compute_audio_embed_size` in audio_processor use torch for + # computation, therefore we re-implement it to use pythonic + # numeric computation to avoid extra tensor conversion. 
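+        # Illustrative example (the rates below are read from the feature
+        # extractor at runtime): with audio_compression_rate=8 and
+        # audio_downsample_rate=1, 1000 frames -> ceil(1000 / 8) = 125
+        # -> ceil(125 / 1) = 125 audio embeddings.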
+ audio_processor = self.get_feature_extractor() + audio_compression_rate = audio_processor.audio_compression_rate + audio_downsample_rate = audio_processor.audio_downsample_rate + + integer = audio_frames // audio_compression_rate + remainder = audio_frames % audio_compression_rate + result = integer + int(remainder > 0) + + integer = result // audio_downsample_rate + remainder = result % audio_downsample_rate + result = integer + int(remainder > 0) # qformer compression + + return result + + +class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 0) + + tokenizer = self.info.get_tokenizer() + image_tokens: str = tokenizer.image_token * num_images + audio_tokens: str = tokenizer.audio_token * num_audios + + return image_tokens + audio_tokens + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "audio": + self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE, + num_audios=num_audios), + } + + return mm_data + + +class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + audio_data = mm_data.pop("audios", []) + if audio_data: + mm_data['audio'] = audio_data + + processed_outputs = super()._call_hf_processor(prompt, mm_data, + mm_kwargs, tok_kwargs) + + if "image_pixel_values" in processed_outputs: + num_img_tokens = [ + self.info.get_num_image_tokens(image_width=img_size[0], + image_height=img_size[1]) + for img_size in processed_outputs["image_sizes"] + ] + processed_outputs["num_img_tokens"] = num_img_tokens + + if audio_data: + audio_features = processed_outputs['audio_input_features'] + sr = self.info.get_feature_extractor().sampling_rate + feature_sizes = [ + self.info.get_audio_num_frames(len(audio), sr) + for audio in audio_data + ] + processed_outputs['audio_input_features'] = [ + audio_features[idx, :size] + for idx, size in enumerate(feature_sizes) + ] + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + image_pixel_values=MultiModalFieldConfig.batched("image"), + image_attention_mask=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + num_img_tokens=MultiModalFieldConfig.batched("image"), + audio_input_features=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: 
MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + tokenizer = self.info.get_tokenizer() + image_token_id = tokenizer.vocab[tokenizer.image_token] + audio_token_id = tokenizer.vocab[tokenizer.audio_token] + + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + audio_processor = self.info.get_feature_extractor() + + def get_image_replacement_phi4mm(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + + image_tokens = [image_token_id] * num_image_tokens + + return image_tokens + + def get_audio_replacement_phi4mm(item_idx: int): + audios = mm_items.get_items("audio", AudioProcessorItems) + # TODO(Isotr0py): support embedding inputs + audio_len = audios.get_audio_length(item_idx) + audio_frames = self.info.get_audio_num_frames( + audio_len, audio_processor.sampling_rate) + audio_embed_size = self.info._compute_audio_embed_size( + audio_frames) + + audio_tokens = [audio_token_id] * audio_embed_size + + return audio_tokens + + return [ + PromptReplacement( + modality="audio", + target=[audio_token_id], + replacement=get_audio_replacement_phi4mm, + ), + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_image_replacement_phi4mm, + ), + ] + + +@MULTIMODAL_REGISTRY.register_processor( + Phi4MMMultiModalProcessor, + info=Phi4MMProcessingInfo, + dummy_inputs=Phi4MMDummyInputsBuilder, +) +class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): + """ + Implements the Phi-4-multimodal-instruct model in vLLM. + """ + packed_modules_mapping = { + "qkv_proj": [ + "qkv_proj", + ], + "gate_up_proj": [ + "gate_up_proj", + ], + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # Multimodal embedding + "model.embed_tokens_extend.": "", + # LLM backbone + "model.": "language_model.model.", + }, + orig_to_new_substr={ + # projection + ".img_projection_": ".img_projection.", + ".up_proj_for_speech.": ".speech_projection.up.", + ".up_proj_for_vision_speech.": ".vision_speech_projection.up.", + ".down_proj_for_speech.": ".speech_projection.down.", + ".down_proj_for_vision_speech.": ".vision_speech_projection.down.", + }, + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|image|>" + if modality.startswith("audio"): + return "<|audio|>" + + raise ValueError("Only image or audio modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + # TODO: Optionally initializes these for supporting input embeddings. 
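+        # Both encoders are constructed unconditionally; which one is used
+        # depends on the modalities present in each request.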
+ self.image_embed = Phi4MMImageEmbedding( + config, + # prefix=maybe_prefix(prefix, "image_embed"), + ) + self.audio_embed = Phi4MMAudioEmbedding( + config, + # prefix=maybe_prefix(prefix, "audio_embed"), + ) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Phi3ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[Phi4MMAudioInputs]: + """ + Parse and validate the audio input to the model. This handles both + audio features and audio embeddings, but only the former is used for + now. + + Args: + kwargs (object): Keyword arguments. + + Returns: + Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs. + """ + audio_features = kwargs.pop("audio_input_features", None) + audio_embeds = kwargs.pop("audio_embeds", None) + + if audio_features is None and audio_embeds is None: + return None + + if audio_features is not None: + if not isinstance(audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio features. " + f"Got type: {type(audio_features)}") + + return Phi4MMAudioFeatureInputs(type="audio_features", + data=flatten_bn(audio_features)) + + if audio_embeds is not None: + if not isinstance(audio_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio embeds. " + f"Got type: {type(audio_embeds)}") + + return Phi4MMAudioEmbeddingInputs(type="audio_embeds", + data=audio_embeds) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input(self, audio_input: Phi4MMAudioInputs, + audio_projection_mode: str) -> NestedTensors: + """ + Create the audio embeddings from the audio input, where the audio input + is pairs of audio features and audio embed lengths. The audio input is + created by `input_mapper_for_phi4mm_audio`. + + Args: + audio_input (Phi4MMAudioInputs): Audio input. + + Returns: + NestedTensors: Audio embeddings + """ + if audio_input["type"] == "audio_embeds": + return audio_input["data"] + + audio_features = audio_input["data"] + # (e.g. multiple examples) and the second dim is the multi-audio dim + # (e.g. multiple audios in the same example) + + dtype = next(self.audio_embed.parameters()).dtype + audio_embeds = [ + self.audio_embed( + features.unsqueeze(0).to(dtype), + audio_projection_mode=audio_projection_mode, + ) for features in audio_features + ] + return audio_embeds + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]: + image_pixel_values: NestedTensors = kwargs.get("image_pixel_values") + if image_pixel_values is None: + return None + + image_sizes = kwargs.get("image_sizes") + image_attention_mask = kwargs.get("image_attention_mask") + num_img_tokens = kwargs.get("num_img_tokens") + assert image_sizes is not None and image_attention_mask is not None\ + and num_img_tokens is not None, "Missing image inputs" + + if is_list_of(image_pixel_values, torch.Tensor): + assert all(p.dim() == 5 + for p in image_pixel_values), "Incorrect image inputs" + # list len is batch_size. + # each tensor has dimension: num_img_per_example, num_hd_patches, + # channels, height, width. + # need to pad along num_hd_patches. + # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w. 
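+            # cat_with_pad pads every non-concat dim (here num_hd_patches)
+            # to the per-batch maximum before concatenating along dim 0.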
+            image_pixel_values = cat_with_pad(image_pixel_values, dim=0)
+        elif isinstance(image_pixel_values, torch.Tensor):
+            # dimension: batch_size, num_img_per_example, num_hd_patches,
+            # channels, height, width.
+            # we flatten the first 2 dims to make it a single large batch for
+            # the SigLIP encoder.
+            assert image_pixel_values.dim() == 6, "Incorrect image inputs"
+            image_pixel_values = image_pixel_values.flatten(0, 1)
+        else:
+            raise ValueError("Incorrect image_pixel_values inputs")
+
+        if isinstance(image_attention_mask, list):
+            image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
+        elif isinstance(image_attention_mask, torch.Tensor):
+            image_attention_mask = image_attention_mask.flatten(0, 1)
+        else:
+            raise ValueError("Incorrect image_attention_mask inputs")
+
+        if isinstance(image_sizes, list):
+            image_sizes = torch.cat(image_sizes, dim=0)
+        elif isinstance(image_sizes, torch.Tensor):
+            image_sizes = image_sizes.flatten(0, 1)
+        else:
+            raise ValueError("Incorrect image_sizes inputs")
+
+        if isinstance(num_img_tokens, list):
+            num_img_tokens = [
+                n for num_tensor in num_img_tokens
+                for n in num_tensor.tolist()
+            ]
+        elif isinstance(num_img_tokens, torch.Tensor):
+            num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
+        else:
+            raise ValueError("Incorrect num_img_tokens inputs")
+
+        return Phi4MMImagePixelInputs(
+            type="pixel_values",
+            data=image_pixel_values,
+            image_sizes=image_sizes,
+            image_attention_mask=image_attention_mask,
+            num_img_tokens=num_img_tokens,
+        )
+
+    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+        modalities = {}
+
+        # Preserve the order of modalities if there are multiple of them
+        # from the order of kwargs.
+        for input_key in kwargs:
+            if input_key in ("image_pixel_values",
+                             "image_embeds") and "images" not in modalities:
+                modalities["images"] = self._parse_and_validate_image_input(
+                    **kwargs)
+            if input_key in ("audio_input_features",
+                             "audio_embeds") and "audios" not in modalities:
+                modalities["audios"] = self._parse_and_validate_audio_input(
+                    **kwargs)
+
+        return modalities
+
+    def _process_image_input(
+            self, image_input: Phi4MMImagePixelInputs) -> list[torch.Tensor]:
+        if image_input["type"] == "image_embeds":
+            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+        else:
+            dtype = next(self.image_embed.parameters()).dtype
+            pixel_values = image_input['data'].to(dtype)
+            image_sizes = image_input['image_sizes']
+            image_attention_mask = image_input['image_attention_mask']
+            image_embeds = self.image_embed(pixel_values, image_sizes,
+                                            image_attention_mask)
+        return image_embeds
+
+    def get_multimodal_embeddings(
+            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+
+        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+        if not modalities:
+            return None
+
+        # The resulting multimodal_embeddings is a tuple of tensors, with each
+        # tensor corresponding to a multimodal data item (image or audio).
+        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+        # NOTE: It is important to iterate over the keys in this dictionary
+        # to preserve the order of the modalities.
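+        # Audio uses the speech projector by default; it switches to the
+        # vision-speech projector when images appear in the same prompt.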
+        audio_projection_mode = 'speech'
+        for modality in modalities:
+            # make sure images are processed first
+            if modality == "images":
+                audio_projection_mode = "vision"
+                image_input = modalities["images"]
+                vision_embeddings = self._process_image_input(image_input)
+                multimodal_embeddings += tuple(vision_embeddings)
+            if modality == "audios":
+                audio_input = modalities["audios"]
+                audio_embeddings = self._process_audio_input(
+                    audio_input, audio_projection_mode=audio_projection_mode)
+                multimodal_embeddings += tuple(audio_embeddings)
+
+        return multimodal_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID])
+        return inputs_embeds
+
+    def get_input_embeddings_v0(
+        self,
+        input_ids: torch.Tensor,
+        image_input: Optional[Phi4MMImagePixelInputs] = None,
+        audio_input: Optional[Phi4MMAudioFeatureInputs] = None,
+    ) -> torch.Tensor:
+        audio_projection_mode = 'speech'
+        inputs_embeds = self.get_input_embeddings(input_ids)
+        if image_input is not None:
+            image_embeds = self._process_image_input(image_input)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                image_embeds,
+                placeholder_token_id=_IMAGE_PLACEHOLDER_TOKEN_ID,
+            )
+            audio_projection_mode = 'vision'
+
+        if audio_input is not None:
+            audio_embeds = self._process_audio_input(
+                audio_input, audio_projection_mode=audio_projection_mode)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                audio_embeds,
+                placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN_ID,
+            )
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs: object,
+    ) -> torch.Tensor:
+        if intermediate_tensors is not None:
+            inputs_embeds = None
+
+        # NOTE: In v1, inputs_embeds is always generated at the model runner
+        # from `get_multimodal_embeddings` and `get_input_embeddings`; this
+        # condition is only for v0 compatibility.
+        elif inputs_embeds is None:
+            image_input = self._parse_and_validate_image_input(**kwargs)
+            audio_input = self._parse_and_validate_audio_input(**kwargs)
+
+            if image_input is None and audio_input is None:
+                inputs_embeds = None
+            else:
+                inputs_embeds = self.get_input_embeddings_v0(
+                    input_ids,
+                    image_input=image_input,
+                    audio_input=audio_input)
+                input_ids = None
+
+        hidden_states = self.language_model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
+
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefixes for the language model, connectors,
+        and tower models.
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model.",
+            connector=[
+                "img_projection", "vision_speech_projection",
+                "speech_projection"
+            ],
+            tower_model=["image_embed", "audio_embed"],
+        )
+
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index fafb6a704383..f33a69235d46 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -220,6 +220,8 @@
     "Ovis": ("ovis", "Ovis"),
     "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
+    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
+    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
     "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
     "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
@@ -228,7 +230,6 @@
     "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
     "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
     "UltravoxModel": ("ultravox", "UltravoxModel"),
-    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
     "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
     "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
     "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 25dd71d877fb..24ddd35abea6 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -295,7 +295,7 @@ def cached_tokenizer_from_config(
     return cached_get_tokenizer(
         model_config.tokenizer,
         tokenizer_mode=model_config.tokenizer_mode,
-        tokenizer_revision=model_config.tokenizer_revision,
+        revision=model_config.tokenizer_revision,
         trust_remote_code=model_config.trust_remote_code,
         **kwargs,
     )
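
With the registry entries above in place, a quick sanity check is to confirm that both architecture names resolve to a registered implementation. This is a minimal sketch, not part of the patch, assuming a patched vLLM install and the public `vllm.ModelRegistry.get_supported_archs()` helper:

from vllm import ModelRegistry

# Both the original and the HF-Transformers-format Phi-4-multimodal
# architectures should now be registered.
for arch in ("Phi4MMForCausalLM", "Phi4MultimodalForCausalLM"):
    assert arch in ModelRegistry.get_supported_archs(), f"{arch} is not registered"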