diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index c8b6c6c86120..ec20375411e4 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -612,6 +612,7 @@ Specified using `--task generate`.
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ | ✅︎ |
| `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ | ✅︎ |
| `QwenVLForConditionalGeneration`^ | Qwen-VL | T + IE+ | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A+ | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ |
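
For context on the new `Phi4MultimodalForCausalLM` row: the HF-format checkpoint bundles its vision and speech LoRA adapters alongside the base weights, so a combined image-plus-audio request looks roughly like the sketch below. It mirrors the vision-speech test added later in this diff; the audio file path, prompt, and sampling settings are only illustrative, not a prescribed usage.

```python
import os

import librosa
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct",
                               revision="refs/pr/70")
vision_lora_path = os.path.join(model_path, "vision-lora")
audio_path = os.path.join(model_path, "examples",
                          "what_is_shown_in_this_image.wav")

llm = LLM(
    model=model_path,
    max_model_len=12800,
    enable_lora=True,
    max_lora_rank=320,
    limit_mm_per_prompt={"image": 1, "audio": 1},
)

outputs = llm.generate(
    {
        "prompt": "<|user|><|image|><|audio|><|end|><|assistant|>",
        "multi_modal_data": {
            "image": ImageAsset("cherry_blossom").pil_image.convert("RGB"),
            # librosa.load returns a (waveform, sample_rate) tuple.
            "audio": [librosa.load(audio_path, sr=16000)],
        },
    },
    SamplingParams(max_tokens=64),
    # As in the vision-speech test below, the vision adapter is used for
    # mixed vision-speech inputs.
    lora_request=LoRARequest("vision", 1, vision_lora_path),
)
print(outputs[0].outputs[0].text)
```
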
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index 8014cb53f16a..01d6a188be99 100644
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -190,6 +190,37 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
)
+def run_phi4_multimodal(question: str, audio_count: int) -> ModelRequestData:
+ """
+ Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
+ show how to process audio inputs.
+ """
+ model_path = snapshot_download(
+ "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70"
+ )
+ # Since the vision-lora and speech-lora co-exist with the base model,
+ # we have to manually specify the path of the lora weights.
+ speech_lora_path = os.path.join(model_path, "speech-lora")
+ placeholders = "<|audio|>" * audio_count
+
+ prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
+
+ engine_args = EngineArgs(
+ model=model_path,
+ max_model_len=12800,
+ max_num_seqs=2,
+ enable_lora=True,
+ max_lora_rank=320,
+ limit_mm_per_prompt={"audio": audio_count},
+ )
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompt=prompts,
+ lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
+ )
+
+
# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
model_name = "Qwen/Qwen2-Audio-7B-Instruct"
@@ -303,6 +334,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
"granite_speech": run_granite_speech,
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,
+ "phi4_multimodal": run_phi4_multimodal,
"qwen2_audio": run_qwen2_audio,
"qwen2_5_omni": run_qwen2_5_omni,
"ultravox": run_ultravox,
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index e4811c023377..bbe7541e0182 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -987,6 +987,41 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
)
+# HF format Phi-4-multimodal-instruct
+def run_phi4_multimodal(questions: list[str], modality: str) -> ModelRequestData:
+ """
+ Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
+ show how to process image inputs.
+ """
+ assert modality == "image"
+ model_path = snapshot_download(
+ "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70"
+ )
+ # Since the vision-lora and speech-lora co-exist with the base model,
+ # we have to manually specify the path of the lora weights.
+ vision_lora_path = os.path.join(model_path, "vision-lora")
+ prompts = [
+ f"<|user|><|image|>{question}<|end|><|assistant|>" for question in questions
+ ]
+ engine_args = EngineArgs(
+ model=model_path,
+ max_model_len=5120,
+ max_num_seqs=2,
+ max_num_batched_tokens=12800,
+ enable_lora=True,
+ max_lora_rank=320,
+ # Note - mm_processor_kwargs can also be passed to generate/chat calls
+ mm_processor_kwargs={"dynamic_hd": 16},
+ limit_mm_per_prompt={"image": 1},
+ )
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ lora_requests=[LoRARequest("vision", 1, vision_lora_path)],
+ )
+
+
# Pixtral HF-format
def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -1244,6 +1279,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
"paligemma2": run_paligemma2,
"phi3_v": run_phi3v,
"phi4_mm": run_phi4mm,
+ "phi4_multimodal": run_phi4_multimodal,
"pixtral_hf": run_pixtral_hf,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index eb4f3b6c8f44..385206b525fe 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -686,6 +686,40 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
)
+def load_phi4_multimodal(question: str, image_urls: list[str]) -> ModelRequestData:
+ """
+ Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
+    show how to process multi-image inputs.
+ """
+
+ model_path = snapshot_download(
+ "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70"
+ )
+ # Since the vision-lora and speech-lora co-exist with the base model,
+ # we have to manually specify the path of the lora weights.
+ vision_lora_path = os.path.join(model_path, "vision-lora")
+ engine_args = EngineArgs(
+ model=model_path,
+ max_model_len=4096,
+ max_num_seqs=2,
+ limit_mm_per_prompt={"image": len(image_urls)},
+ enable_lora=True,
+ max_lora_rank=320,
+ # Note - mm_processor_kwargs can also be passed to generate/chat calls
+ mm_processor_kwargs={"dynamic_hd": 4},
+ )
+
+ placeholders = "<|image|>" * len(image_urls)
+ prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompt=prompt,
+ image_data=[fetch_image(url) for url in image_urls],
+ lora_requests=[LoRARequest("vision", 1, vision_lora_path)],
+ )
+
+
def load_qwen_vl_chat(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Qwen/Qwen-VL-Chat"
engine_args = EngineArgs(
@@ -912,6 +946,7 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData:
"ovis": load_ovis,
"phi3_v": load_phi3v,
"phi4_mm": load_phi4mm,
+ "phi4_multimodal": load_phi4_multimodal,
"pixtral_hf": load_pixtral_hf,
"qwen_vl_chat": load_qwen_vl_chat,
"qwen2_vl": load_qwen2_vl,
diff --git a/tests/models/multimodal/generation/test_phi4_multimodal.py b/tests/models/multimodal/generation/test_phi4_multimodal.py
new file mode 100644
index 000000000000..db8984d8656f
--- /dev/null
+++ b/tests/models/multimodal/generation/test_phi4_multimodal.py
@@ -0,0 +1,252 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+from collections.abc import Sequence
+from typing import Optional
+
+import librosa
+import pytest
+from huggingface_hub import snapshot_download
+
+from vllm.assets.image import ImageAsset
+from vllm.lora.request import LoRARequest
+from vllm.multimodal.image import rescale_image_size
+from vllm.platforms import current_platform
+
+from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput,
+ PromptImageInput, VllmRunner)
+from ....utils import large_gpu_test
+from ...utils import check_logprobs_close
+
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+ "stop_sign":
+ "<|user|>\n<|image|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501
+ "cherry_blossom":
+ "<|user|>\n<|image|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501
+})
+HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image|>\n<|image|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501
+
+model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct",
+ revision="refs/pr/70")
+# Since the vision-lora and speech-lora co-exist with the base model,
+# we have to manually specify the path of the lora weights.
+vision_lora_path = os.path.join(model_path, "vision-lora")
+speech_question = os.path.join(model_path, "examples",
+ "what_is_shown_in_this_image.wav")
+models = [model_path]
+
+target_dtype = "half"
+
+# ROCm Triton FA can run into shared memory issues with these models,
+# use other backends in the meantime
+# FIXME (mattwong, gshtrasb, hongxiayan)
+if current_platform.is_rocm():
+ os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
+
+
+def run_test(
+ hf_runner: type[HfRunner],
+ vllm_runner: type[VllmRunner],
+ inputs: Sequence[tuple[list[str], PromptImageInput,
+ Optional[PromptAudioInput]]],
+ model: str,
+ *,
+ max_model_len: int,
+ dtype: str,
+ max_tokens: int,
+ num_logprobs: int,
+ mm_limit: int,
+ tensor_parallel_size: int,
+ distributed_executor_backend: Optional[str] = None,
+):
+ """Inference result should be the same between hf and vllm.
+
+ All the image fixtures for the test are from IMAGE_ASSETS.
+ For huggingface runner, we provide the PIL images as input.
+ For vllm runner, we provide MultiModalDataDict objects
+ and corresponding MultiModalConfig as input.
+ Note, the text input is also adjusted to abide by vllm contract.
+ The text output is sanitized to be able to compare with hf.
+ """
+ # NOTE: take care of the order. run vLLM first, and then run HF.
+ # vLLM needs a fresh new process without cuda initialization.
+ # if we run HF first, the cuda initialization will be done and it
+ # will hurt multiprocessing backend with fork method (the default method).
+ # max_model_len should be greater than image_feature_size
+ with vllm_runner(
+ model,
+ task="generate",
+ max_model_len=max_model_len,
+ max_num_seqs=2,
+ dtype=dtype,
+ limit_mm_per_prompt={"image": mm_limit},
+ tensor_parallel_size=tensor_parallel_size,
+ distributed_executor_backend=distributed_executor_backend,
+ enable_lora=True,
+ max_lora_rank=320,
+ gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI
+ enforce_eager=True,
+ trust_remote_code=False,
+ ) as vllm_model:
+ lora_request = LoRARequest("vision", 1, vision_lora_path)
+ vllm_outputs_per_case = [
+ vllm_model.generate_greedy_logprobs(prompts,
+ max_tokens,
+ num_logprobs=num_logprobs,
+ images=images,
+ audios=audios,
+ lora_request=lora_request)
+ for prompts, images, audios in inputs
+ ]
+
+ with hf_runner(model, dtype=dtype) as hf_model:
+ hf_model.model.load_adapter(
+ vision_lora_path,
+ adapter_name="vision",
+ )
+ hf_processor = hf_model.processor
+ eos_token_id = hf_processor.tokenizer.eos_token_id
+ hf_outputs_per_case = [
+ hf_model.generate_greedy_logprobs_limit(prompts,
+ max_tokens,
+ num_logprobs=num_logprobs,
+ images=images,
+ audios=audios,
+ eos_token_id=eos_token_id)
+ for prompts, images, audios in inputs
+ ]
+
+ for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
+ vllm_outputs_per_case):
+ check_logprobs_close(
+ outputs_0_lst=hf_outputs,
+ outputs_1_lst=vllm_outputs,
+ name_0="hf",
+ name_1="vllm",
+ )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+ "size_factors",
+ [
+ # No image
+ [],
+ # Single-scale
+ [1.0],
+ # Single-scale, batched
+ [1.0, 1.0, 1.0],
+ # Multi-scale
+ [0.25, 0.5, 1.0],
+ ],
+)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_model_len", [12800])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
+ dtype: str, max_model_len: int, max_tokens: int,
+ num_logprobs: int) -> None:
+ images = [asset.pil_image for asset in image_assets]
+
+ inputs_per_image = [(
+ [prompt for _ in size_factors],
+ [rescale_image_size(image, factor) for factor in size_factors],
+ None,
+ ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+
+ run_test(
+ hf_runner,
+ vllm_runner,
+ inputs_per_image,
+ model,
+ dtype=dtype,
+ max_model_len=max_model_len,
+ max_tokens=max_tokens,
+ num_logprobs=num_logprobs,
+ mm_limit=1,
+ tensor_parallel_size=1,
+ )
+
+
+@large_gpu_test(min_gb=48)
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+ "size_factors",
+ [
+ # No image
+ # [],
+ # Single-scale
+ [1.0],
+ # Single-scale, batched
+ [1.0, 1.0, 1.0],
+ # Multi-scale
+ [0.25, 0.5, 1.0],
+ ],
+)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_model_len", [25600])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
+ size_factors, dtype: str, max_model_len: int,
+ max_tokens: int, num_logprobs: int) -> None:
+ images = [asset.pil_image for asset in image_assets]
+
+ inputs_per_case = [
+ (
+ [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+ [[rescale_image_size(image, factor) for image in images]
+ for factor in size_factors],
+ None,
+ ),
+ ]
+
+ run_test(
+ hf_runner,
+ vllm_runner,
+ inputs_per_case,
+ model,
+ dtype=dtype,
+ max_model_len=max_model_len,
+ max_tokens=max_tokens,
+ num_logprobs=num_logprobs,
+ mm_limit=2,
+ tensor_parallel_size=1,
+ )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_model_len", [12800])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
+ max_model_len: int, max_tokens: int,
+ num_logprobs: int) -> None:
+
+ # use the example speech question so that the model outputs are reasonable
+ audio = librosa.load(speech_question, sr=16000)
+ image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+
+ inputs_vision_speech = [
+ (
+ ["<|user|><|image|><|audio|><|end|><|assistant|>"],
+ [image],
+ [audio],
+ ),
+ ]
+
+ run_test(
+ hf_runner,
+ vllm_runner,
+ inputs_vision_speech,
+ model,
+ dtype=dtype,
+ max_model_len=max_model_len,
+ max_tokens=max_tokens,
+ num_logprobs=num_logprobs,
+ mm_limit=1,
+ tensor_parallel_size=1,
+ )
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index fd5842523178..627c6715fbc7 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -41,12 +41,18 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
def _test_processing_correctness(
- model_id: str,
+ model_id_or_arch: str,
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
- model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
+ if model_id_or_arch in HF_EXAMPLE_MODELS.get_supported_archs():
+ # Use model architecture to get the default model id
+ model_info = HF_EXAMPLE_MODELS.get_hf_info(model_id_or_arch)
+ model_id = model_info.default
+ else:
+ model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id_or_arch)
+ model_id = model_id_or_arch
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
@@ -58,7 +64,7 @@ def _test_processing_correctness(
trust_remote_code=model_info.trust_remote_code,
seed=0,
dtype="auto",
- revision=None,
+ revision=model_info.revision,
hf_overrides=model_info.hf_overrides,
)
@@ -330,6 +336,28 @@ def test_processing_correctness(
)
+# Phi4MultimodalForCausalLM shares the same model repo with the original
+# format Phi4MMForCausalLM, so we add it as a separate test case.
+# Remove this test after the conversion PR is merged:
+# https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/70
+@pytest.mark.parametrize("model_arch", ["Phi4MultimodalForCausalLM"])
+@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
+@pytest.mark.parametrize("num_batches", [32])
+@pytest.mark.parametrize("simplify_rate", [1.0])
+def test_processing_correctness_phi4_multimodal(
+ model_arch: str,
+ hit_rate: float,
+ num_batches: int,
+ simplify_rate: float,
+):
+ _test_processing_correctness(
+ model_arch,
+ hit_rate=hit_rate,
+ num_batches=num_batches,
+ simplify_rate=simplify_rate,
+ )
+
+
def _assert_inputs_equal(
a: MultiModalInputs,
b: MultiModalInputs,
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 84ca0bc60003..574f0a3ff20f 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -425,6 +425,8 @@ def check_available_online(
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True),
+ "Phi4MultimodalForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", # noqa: E501
+ revision="refs/pr/70"),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
tokenizer_mode="mistral"),
"QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py
new file mode 100644
index 000000000000..432b707a6159
--- /dev/null
+++ b/vllm/model_executor/models/phi4_multimodal.py
@@ -0,0 +1,1455 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Any, Literal, Optional, TypedDict, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import (BatchFeature, Phi4MultimodalAudioConfig,
+ Phi4MultimodalConfig, Phi4MultimodalFeatureExtractor,
+ Phi4MultimodalImageProcessorFast)
+from transformers import Phi4MultimodalProcessor as Phi4MMProcessor
+from transformers.models.phi4_multimodal.modeling_phi4_multimodal import (
+ Phi4MultimodalAudioConvModule, Phi4MultimodalAudioNemoConvSubsampling,
+ Phi4MultimodalAudioRelativeAttentionBias, adaptive_enc_mask, unfold_tensor)
+
+from vllm.config import VllmConfig
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size)
+from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalKwargs, NestedTensors)
+from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems,
+ ImageProcessorItems, ImageSize,
+ MultiModalDataItems, MultiModalDataParser)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement,
+ PromptUpdate)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.utils import is_list_of
+
+from .idefics2_vision_model import Idefics2VisionTransformer
+from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+ init_vllm_registered_model, maybe_prefix,
+ merge_multimodal_embeddings)
+
+# <|endoftext10|> (see vocab.json in hf model)
+_IMAGE_PLACEHOLDER_TOKEN_ID = 200010
+# <|endoftext11|>
+_AUDIO_PLACEHOLDER_TOKEN_ID = 200011
+
+_AUDIO_MAX_SOUNDFILE_SIZE = 241_000
+
+
+def _get_padding_size(orig_width: int, orig_height: int, target_height: int,
+ target_width: int):
+ ratio_width = target_width / orig_width
+ ratio_height = target_height / orig_height
+
+ if ratio_width < ratio_height:
+ padding_width = 0
+ padding_height = target_height - int(orig_height * ratio_width)
+ else:
+ padding_width = target_width - int(orig_width * ratio_height)
+ padding_height = 0
+ return padding_height, padding_width
+
+
+class Phi4MMProjector(nn.Module):
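+    """Two-layer projector (input_size -> hidden_size -> hidden_size) with a
+    GELU activation in between, built from tensor-parallel linear layers."""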
+
+ def __init__(self, input_size: int, hidden_size: int):
+ super().__init__()
+ self.up = ColumnParallelLinear(input_size, hidden_size)
+ self.down = RowParallelLinear(hidden_size, hidden_size)
+ self.act = get_act_fn("gelu")
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x, _ = self.up(x)
+ x = self.act(x)
+ x, _ = self.down(x)
+ return x
+
+
+class Phi4MMImageEmbedding(nn.Module):
+ """Image embedding."""
+
+ def __init__(self, config: Phi4MultimodalConfig):
+ super().__init__()
+ self.config = config
+ self.layer_idx = config.vision_config.feature_layer
+ self.crop_size = config.vision_config.crop_size
+ self.image_dim_out = config.vision_config.hidden_size
+
+ n_patches = (config.vision_config.image_size //
+ config.vision_config.patch_size)
+ if n_patches % 2 != 0:
+ self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1))
+ n_patches += 1
+ self.num_img_tokens = (n_patches // 2)**2
+
+ num_hidden_layers = (config.vision_config.num_hidden_layers +
+ self.layer_idx +
+ 1 if self.layer_idx < 0 else self.layer_idx + 1)
+ self.img_processor = Idefics2VisionTransformer(
+ config.vision_config,
+ require_post_norm=False,
+ num_hidden_layers_override=num_hidden_layers)
+ self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2)
+ self.img_projection = Phi4MMProjector(self.image_dim_out,
+ config.hidden_size)
+ self.global_img_feature_extensor = nn.Parameter(
+ torch.zeros([1, 1, self.image_dim_out]))
+ self.sub_img_feature_extensor = nn.Parameter(
+ torch.zeros([1, 1, 1, self.image_dim_out]))
+
+ def get_img_features(
+ self,
+ img_embeds: torch.FloatTensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ img_feature = self.img_processor(img_embeds,
+ patch_attention_mask=attention_mask)
+
+ patch_feature = img_feature
+ # reshape to 2D tensor
+ width = int(math.sqrt(patch_feature.size(1)))
+ patch_feature = patch_feature.view(-1, width, width,
+ patch_feature.size(-1))
+ # convert to NCHW
+ patch_feature = patch_feature.permute(0, 3, 1, 2)
+ if getattr(self, "img_processor_padding", None) is not None:
+ patch_feature = self.img_processor_padding(patch_feature)
+ patch_feature = self.image_token_compression(patch_feature)
+ # convert to NHWC
+ patch_feature = patch_feature.permute(0, 2, 3, 1)
+ patch_feature = patch_feature.view(
+ -1,
+ patch_feature.size(1) * patch_feature.size(2),
+ patch_feature.size(-1))
+ return patch_feature
+
+ def forward(
+ self,
+ image_pixel_values: torch.FloatTensor,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ image_pixel_values = image_pixel_values.to(
+ self.img_processor.embeddings.patch_embedding.weight.dtype)
+
+ target_device = self.img_projection.up.bias.device
+ target_dtype = self.img_projection.up.bias.dtype
+
+ batch_size = image_pixel_values.shape[0]
+
+ img_features = self.get_img_features(
+ image_pixel_values.flatten(0, 1),
+ attention_mask=image_attention_mask.flatten(0, 1).to(
+ dtype=bool, device=target_device),
+ )
+ base_feat_size = int(np.sqrt(img_features.shape[1]))
+ img_features = img_features.view(batch_size, -1, base_feat_size**2,
+ self.image_dim_out)
+ image_sizes = image_sizes.view(-1, 2)
+
+ output_imgs = []
+ for idx in range(batch_size):
+ height, width = image_sizes[idx]
+ height_ratio = height // self.crop_size
+ width_ratio = width // self.crop_size
+ area_ratio = height_ratio * width_ratio
+
+ global_img = img_features[idx, :1]
+ global_img = global_img.reshape(1, base_feat_size, base_feat_size,
+ self.image_dim_out).contiguous()
+ temporary_extensor = self.sub_img_feature_extensor.repeat(
+ 1, base_feat_size, 1, 1)
+ global_img = torch.cat([global_img, temporary_extensor],
+ dim=2).reshape(1, -1, self.image_dim_out)
+
+ sub_img = img_features[idx, 1:]
+ sub_img = sub_img[:area_ratio]
+ sub_img = (sub_img.reshape(
+ height_ratio, width_ratio, base_feat_size, base_feat_size,
+ self.image_dim_out).transpose(1, 2).reshape(
+ 1, height_ratio * base_feat_size,
+ width_ratio * base_feat_size,
+ self.image_dim_out).contiguous())
+
+ if image_attention_mask is not None:
+ reshaped_image_attention_mask = (
+ image_attention_mask[idx, 1:area_ratio + 1,
+ 0::2, 0::2].reshape(
+ height_ratio, width_ratio,
+ base_feat_size,
+ base_feat_size).transpose(
+ 1, 2).reshape(
+ 1, height_ratio *
+ base_feat_size,
+ width_ratio *
+ base_feat_size))
+ useful_height = int(
+ reshaped_image_attention_mask[0, :, 0].sum().item())
+ useful_width = int(
+ reshaped_image_attention_mask[0, 0, :].sum().item())
+ sub_img = sub_img[:, :useful_height, :useful_width]
+ temporary_extensor = self.sub_img_feature_extensor.repeat(
+ 1, useful_height, 1, 1)
+ else:
+ temporary_extensor = self.sub_img_feature_extensor.repeat(
+ 1, height_ratio * base_feat_size, 1, 1)
+
+ sub_img = torch.cat([sub_img, temporary_extensor],
+ dim=2).reshape(1, -1, self.image_dim_out)
+
+ # Merge global and sub
+ output_imgs.append(
+ torch.cat(
+ [sub_img, self.global_img_feature_extensor, global_img],
+ dim=1))
+
+ img_set_tensor = []
+ for output_img in output_imgs:
+ output_img = output_img.to(device=target_device,
+ dtype=target_dtype)
+ img_feature_proj = self.img_projection(output_img)
+ img_set_tensor.append(img_feature_proj.flatten(0, 1))
+
+ return img_set_tensor
+
+
+class Phi4MultimodalAudioMLP(nn.Module):
+
+ def __init__(
+ self,
+ config: Phi4MultimodalAudioConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
+ self.act_fn = MulAndSilu()
+ self.gate_up_proj = MergedColumnParallelLinear(
+ config.hidden_size, [config.intermediate_size] * 2,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj")
+ self.down_proj = RowParallelLinear(config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.down_proj")
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states, _ = self.gate_up_proj(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states, _ = self.down_proj(hidden_states)
+ return hidden_states
+
+
+class Phi4MultimodalAudioAttention(nn.Module):
+
+ def __init__(
+ self,
+ config: Phi4MultimodalAudioConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.total_num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.total_num_heads
+ if self.head_dim * self.total_num_heads != self.embed_dim:
+ raise ValueError(
+ "embed_dim must be divisible by num_heads "
+ f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.total_num_heads}).")
+ self.scale = self.head_dim**-0.5
+
+ self.qkv_proj = QKVParallelLinear(
+ hidden_size=self.embed_dim,
+ head_size=self.head_dim,
+ total_num_heads=self.total_num_heads,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+
+ self.o_proj = RowParallelLinear(
+ input_size=self.embed_dim,
+ output_size=self.embed_dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.out_proj",
+ )
+
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.tp_rank = get_tensor_model_parallel_rank()
+ self.num_heads = divide(self.total_num_heads, self.tp_size)
+
+ def split_attn_mask(self, attention_mask: torch.Tensor) -> torch.Tensor:
+ start_idx = self.num_heads * self.tp_rank
+ end_idx = self.num_heads * (self.tp_rank + 1)
+ return attention_mask[:, start_idx:end_idx]
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ qkv_states, _ = self.qkv_proj(hidden_states)
+ query, key, value = qkv_states.chunk(3, dim=-1)
+
+ bsz, seq_len, _ = query.size()
+ query = query.view(bsz, seq_len, self.num_heads, self.head_dim)
+ key = key.view(bsz, seq_len, self.num_heads, self.head_dim)
+ value = value.view(bsz, seq_len, self.num_heads, self.head_dim)
+ query, key, value = (x.transpose(1, 2) for x in (query, key, value))
+
+ attention_mask = self.split_attn_mask(attention_mask)
+ out = F.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ scale=self.scale,
+ attn_mask=attention_mask,
+ )
+ out = out.transpose(1, 2).reshape(bsz, seq_len, -1)
+
+ attn_output, _ = self.o_proj(out)
+
+ return attn_output
+
+
+class Phi4MultimodalAudioConformerEncoderLayer(nn.Module):
+
+ def __init__(self, config: Phi4MultimodalAudioConfig):
+ super().__init__()
+
+ self.feed_forward_in = Phi4MultimodalAudioMLP(config)
+ self.self_attn = Phi4MultimodalAudioAttention(config)
+ self.conv = Phi4MultimodalAudioConvModule(config)
+ self.feed_forward_out = Phi4MultimodalAudioMLP(config)
+ self.layer_norm_att = nn.LayerNorm(config.hidden_size)
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states)
+ hidden_states = self.layer_norm_att(residual)
+
+ hidden_states = residual + self.self_attn(hidden_states,
+ attention_mask)
+ hidden_states = hidden_states + self.conv(hidden_states)
+ hidden_states = hidden_states + 0.5 * self.feed_forward_out(
+ hidden_states)
+
+ out = self.layer_norm(hidden_states)
+
+ return out
+
+
+class Phi4MMAudioMeanVarianceNormLayer(nn.Module):
+ """Mean/variance normalization layer.
+
+    Subtracts the mean and multiplies the input by the inverted standard
+    deviation. Typically used as the very first layer in a model.
+
+    Args:
+        config: Phi4MultimodalAudioConfig
+            audio config providing `input_size`, the input feature dimension.
+ """
+
+ def __init__(self, config: Phi4MultimodalAudioConfig):
+ super().__init__()
+ self.global_mean = nn.Parameter(torch.zeros(config.input_size))
+ self.global_invstd = nn.Parameter(torch.ones(config.input_size))
+
+ def forward(self, input_: torch.Tensor) -> torch.Tensor:
+ """MeanVarianceNormLayer Forward
+
+ Args:
+ input_: torch.Tensor
+ input tensor.
+ """
+ return (input_ - self.global_mean) * self.global_invstd
+
+
+class Phi4MultimodalAudioModel(nn.Module):
+
+ def __init__(self, config: Phi4MultimodalAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.encoder_embedding = Phi4MMAudioMeanVarianceNormLayer(config)
+ self.embed = Phi4MultimodalAudioNemoConvSubsampling(config)
+ self.relative_attention_bias_layer = (
+ Phi4MultimodalAudioRelativeAttentionBias(config))
+ self.encoders = nn.ModuleList([
+ Phi4MultimodalAudioConformerEncoderLayer(config)
+ for _ in range(config.num_blocks)
+ ])
+
+ def _streaming_mask(
+ self,
+ seq_len: int,
+ batch_size: int,
+ chunk_size: int,
+ left_chunk: int,
+ ):
+        # Create the mask matrix for streaming.
+        # chunk_start_idx stores the start index of each chunk;
+        # e.g. if chunk_size is 18, it is [0, 18, 36, ...].
+ chunk_start_idx = np.arange(0, seq_len, chunk_size)
+
+ enc_streaming_mask = (adaptive_enc_mask(
+ seq_len, chunk_start_idx,
+ left_window=left_chunk).unsqueeze(0).expand([batch_size, -1, -1]))
+ return enc_streaming_mask
+
+ def forward_embeddings(
+ self,
+ hidden_states: torch.Tensor,
+ masks: torch.Tensor,
+ ):
+ """Forwarding the inputs through the top embedding layers"""
+ seq_len = math.ceil(hidden_states.shape[1] /
+ self.config.time_reduction)
+ if seq_len <= 0:
+ raise ValueError(
+ f"Sequence length after time reduction is invalid: {seq_len}."
+ "Your input feature is too short.")
+
+ batch_size = hidden_states.shape[0]
+
+ enc_streaming_mask = self._streaming_mask(seq_len, batch_size,
+ self.config.chunk_size,
+ self.config.left_chunk)
+ enc_streaming_mask = enc_streaming_mask.to(hidden_states.device)
+
+ hidden_states, masks = self.embed(hidden_states, masks)
+
+ streaming_mask = enc_streaming_mask
+ if streaming_mask is not None and masks is not None:
+ hs_mask = masks & streaming_mask
+ elif masks is not None:
+ hs_mask = masks
+ else:
+ hs_mask = streaming_mask
+
+ return hidden_states, hs_mask, masks
+
+ def calculate_hs_mask(self, hidden_states: torch.Tensor,
+ device: torch.device, mask: torch.Tensor):
+ max_audio_length = hidden_states.shape[1]
+ batch_size = hidden_states.shape[0]
+ enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size,
+ self.config.chunk_size,
+ self.config.left_chunk)
+ enc_streaming_mask = enc_streaming_mask.to(device)
+ if mask is None:
+ return enc_streaming_mask
+
+ feature_lens = mask.sum(1)
+ padding_length = feature_lens
+ pad_mask = torch.arange(0, max_audio_length, device=device).expand(
+ padding_length.size(0), -1) < padding_length.unsqueeze(1)
+ pad_mask = pad_mask.unsqueeze(1)
+ pad_mask = pad_mask & enc_streaming_mask
+ return pad_mask
+
+ def forward(self,
+ hidden_states: torch.Tensor,
+ mask: Optional[torch.Tensor] = None):
+ hidden_states = self.encoder_embedding(hidden_states)
+ hidden_states, hs_mask, mask = self.forward_embeddings(
+ hidden_states, mask)
+
+ unfolded = False
+ bs, seq_len, _ = hidden_states.shape
+ max_seq_len = 500 # maximum position for absolute positional encoding
+ if seq_len > max_seq_len:
+ # audio sequence is longer than max_seq_len,
+ # unfold it into chunks of max_seq_len
+ unfolded = True
+ # the unfold op will drop residual frames,
+ # pad it to the multiple of max_seq_len
+ if seq_len % max_seq_len > 0:
+ chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
+ else:
+ chunk_pad_size = 0
+ if chunk_pad_size > 0:
+ hidden_states_pad = F.pad(hidden_states,
+ (0, 0, 0, chunk_pad_size),
+ "constant", 0)
+ hidden_states = hidden_states_pad.to(hidden_states.device)
+
+ hidden_states = unfold_tensor(hidden_states, max_seq_len)
+ masks_unfold = None
+ if mask is not None:
+                # revise hs_mask here because the previously calculated
+                # hs_mask did not consider the extra padding
+                subsampled_pad_mask = mask.squeeze(
+                    1)  # [bz, subsampled_unmask_seq_len]
+                extra_padded_subsampled_pad_mask = F.pad(
+                    subsampled_pad_mask, (0, chunk_pad_size), "constant",
+                    False)  # extra padding to the pad mask
+                extra_padded_subsampled_pad_mask = (
+                    extra_padded_subsampled_pad_mask.unsqueeze(-1).float())
+                masks_unfold = unfold_tensor(
+                    extra_padded_subsampled_pad_mask, max_seq_len
+                )  # unfold the pad mask like we did to the input tensor
+ masks_unfold = masks_unfold.squeeze(
+ -1).bool() # unfold op does not support bool tensor
+ hs_mask = self.calculate_hs_mask(
+ hidden_states, hidden_states.device, masks_unfold
+ ) # calculate hs_mask based on the unfolded pad mask
+
+ relative_attention_bias = self.relative_attention_bias_layer(
+ hidden_states)
+ attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias
+
+ for layer in self.encoders:
+ hidden_states = layer(hidden_states, attention_mask)
+
+ if unfolded:
+ embed_dim = hidden_states.shape[-1]
+ hidden_states = hidden_states.reshape(bs, -1, embed_dim)
+ # if we ever padded before unfolding, we need to remove the padding
+ if chunk_pad_size > 0:
+ hidden_states = hidden_states[:, :-chunk_pad_size, :]
+
+ return hidden_states
+
+
+class Phi4MMAudioEmbedding(nn.Module):
+
+ def __init__(self, config: Phi4MultimodalConfig):
+ super().__init__()
+ self.config = config
+ self.layer_idx = config.audio_config.feature_layer
+
+ self.encoder = Phi4MultimodalAudioModel(config.audio_config)
+
+ audio_config = config.audio_config
+ proj_input_size = (audio_config.hidden_size *
+ audio_config.downsample_rate)
+ self.vision_speech_projection = Phi4MMProjector(
+ proj_input_size, config.hidden_size)
+ self.speech_projection = Phi4MMProjector(proj_input_size,
+ config.hidden_size)
+
+ def get_projection(
+ self,
+ audio_projection_mode: Literal["speech", "vision"],
+ ) -> Phi4MMProjector:
+ if audio_projection_mode == "speech":
+ return self.speech_projection
+ elif audio_projection_mode == "vision":
+ return self.vision_speech_projection
+
+ def forward(
+ self,
+ audio_input_features: torch.FloatTensor,
+ audio_embed_sizes=None,
+ audio_attention_mask=None,
+ audio_projection_mode="speech",
+ ) -> torch.FloatTensor:
+
+ audio_projection = self.get_projection(audio_projection_mode)
+
+ target_device = audio_projection.up.bias.device
+ target_dtype = audio_projection.up.bias.dtype
+
+ audio_input_features = audio_input_features.to(device=target_device,
+ dtype=target_dtype)
+
+ audio_encoder_hidden_states = self.encoder(audio_input_features,
+ audio_attention_mask)
+ audio_embeds = audio_projection(audio_encoder_hidden_states)
+
+ return audio_embeds.flatten(0, 1)
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ]
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+class Phi4MMImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ data: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ Shape:
+ `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
+
+ Note that `num_patches` may be different per batch and image,
+ in which case the data is passed as a list instead of a batched tensor.
+ """
+
+ image_sizes: torch.Tensor
+ """
+ Shape: `(batch_size * num_images, 2)`
+
+ This should be in `(height, width)` format.
+ """
+
+ num_img_tokens: list[int]
+ """Shape: `(batch_size * num_images)`"""
+
+ image_attention_mask: torch.Tensor
+ """Shape: `(batch_size * num_images, H_mask, W_mask)`"""
+
+
+class Phi4MMImageEmbeddingInputs(TypedDict):
+ type: Literal["image_embeds"]
+ data: Union[torch.Tensor, list[torch.Tensor]]
+ """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+
+ `hidden_size` must match the hidden size of language model backbone.
+ """
+
+
+class Phi4MMAudioFeatureInputs(TypedDict):
+ type: Literal["audio_features"]
+ data: Union[torch.Tensor, list[torch.Tensor]]
+    """Shape: `(batch_size * num_audios, 80, M)`"""
+
+
+class Phi4MMAudioEmbeddingInputs(TypedDict):
+ type: Literal["audio_embeds"]
+ data: NestedTensors
+    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""
+
+
+Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs]
+Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]
+
+
+def cat_with_pad(tensors, dim, padding_value=0):
+ """
+ cat along dim, while pad to max for all other dims
+ """
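+    # e.g. cat_with_pad([a, b], dim=0) with a of shape (2, 3) and b of shape
+    # (2, 5) returns a (4, 5) tensor, where a's missing columns are filled
+    # with `padding_value`.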
+ ndim = tensors[0].dim()
+ assert all(
+ t.dim() == ndim for t in
+ tensors[1:]), "All tensors must have the same number of dimensions"
+
+ out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
+ out_size[dim] = sum(t.shape[dim] for t in tensors)
+ output = tensors[0].new_full(out_size, padding_value)
+
+ index = 0
+ for t in tensors:
+ # Create a slice list where every dimension except dim is full slice
+ slices = [slice(0, t.shape[d]) for d in range(ndim)]
+ # Update only the concat dimension slice
+ slices[dim] = slice(index, index + t.shape[dim])
+
+ output[slices] = t
+ index += t.shape[dim]
+
+ return output
+
+
+class Phi4MMProcessingInfo(BaseProcessingInfo):
+
+ def get_hf_config(self) -> Phi4MultimodalConfig:
+ return self.ctx.get_hf_config(Phi4MultimodalConfig)
+
+ def get_hf_processor(
+ self,
+ *,
+ dynamic_hd: Optional[int] = None,
+ **kwargs: object,
+ ) -> Phi4MMProcessor:
+ if dynamic_hd is not None:
+ kwargs["dynamic_hd"] = dynamic_hd
+
+ return self.ctx.get_hf_processor(**kwargs)
+
+ def get_feature_extractor(self) -> Phi4MultimodalFeatureExtractor:
+ return self.get_hf_processor().audio_processor
+
+ def get_image_processor(
+ self,
+ processor: Optional[Phi4MMProcessor] = None,
+ ) -> Phi4MultimodalImageProcessorFast:
+ if processor is None:
+ processor = self.get_hf_processor()
+ return processor.image_processor
+
+ def get_dynamic_hd(
+ self,
+ processor: Optional[Phi4MMProcessor] = None,
+ ) -> int:
+ return self.get_image_processor(processor).dynamic_hd
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"audio": None, "image": None}
+
+ def _find_target_aspect_ratio(
+ self,
+ orig_width: int,
+ orig_height: int,
+ image_size: int,
+ max_num: int,
+ min_num: int,
+ ):
+ w_crop_num = math.ceil(orig_width / float(image_size))
+ h_crop_num = math.ceil(orig_height / float(image_size))
+ if w_crop_num * h_crop_num > max_num:
+ aspect_ratio = orig_width / orig_height
+
+ # calculate the existing image aspect ratio
+ target_ratios = set((i, j) for i in range(1, max_num + 1)
+ for j in range(1, max_num + 1)
+ if i * j <= max_num and i * j >= min_num)
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ image_processor = self.get_image_processor()
+ target_aspect_ratio = image_processor.find_closest_aspect_ratio(
+ aspect_ratio,
+ target_ratios,
+ orig_width,
+ orig_height,
+ )
+
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ else:
+ target_width = image_size * w_crop_num
+ target_height = image_size * h_crop_num
+ target_aspect_ratio = (w_crop_num, h_crop_num)
+ return target_aspect_ratio, target_height, target_width
+
+ def _compute_num_image_tokens(
+ self,
+ orig_width: int,
+ orig_height: int,
+ dynamic_hd_size: int,
+ vit_image_size: int,
+ vit_patch_size: int,
+ token_compression_factor: int = 2,
+ ):
+        """
+        Compute the number of tokens an image is expected to take up,
+        considering the image encoder architecture, and exclude output
+        features that contain only padding pixels.
+
+        For SigLIP, vit_image_size=448 and vit_patch_size=14, so the output
+        is a 32x32 feature map.
+        NOTE: right now, Phi4MM uses a hard-coded token_compression_factor=2.
+        """
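+        # Worked example (illustrative numbers): an 800x600 image with
+        # vit_image_size=448, vit_patch_size=14, token_compression_factor=2
+        # and a dynamic_hd_size of at least 4 is tiled into a 2x2 grid
+        # (896x896) with 224 rows of padding. That yields a 64x48 non-padding
+        # feature map, compressed to 32x24, for a total of
+        # 256 (global) + 1 (separator) + 768 (HD patches) + 24 (HD newlines)
+        # + 16 (global newlines) = 1065 image tokens.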
+ assert vit_image_size % vit_patch_size == 0, (
+ "vit_image_size must be divisible by vit_patch_size")
+ assert (vit_image_size // vit_patch_size %
+ token_compression_factor == 0), (
+ "vit_image_size // vit_patch_size must be divisible by "
+ "token_compression_factor")
+
+ target_aspect_ratio, target_height, target_width = (
+ self._find_target_aspect_ratio(orig_width,
+ orig_height,
+ vit_image_size,
+ dynamic_hd_size,
+ min_num=1))
+ assert target_aspect_ratio[0] * vit_image_size == target_width, (
+ f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}")
+ assert target_aspect_ratio[1] * vit_image_size == target_height, (
+ f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}")
+ assert (target_height % vit_image_size == 0
+ and target_width % vit_image_size == 0)
+
+ padding_height, padding_width = _get_padding_size(
+ orig_width, orig_height, target_height, target_width)
+ assert padding_width == 0 or padding_height == 0, \
+ "padding_width or padding_height must be 0"
+
+ target_feat_width = target_width // vit_patch_size
+ target_feat_height = target_height // vit_patch_size
+ if padding_width >= vit_patch_size:
+ assert padding_height == 0, "padding_height not 0"
+ non_pad_feat_width = target_feat_width - math.floor(
+ padding_width / vit_patch_size)
+ non_pad_feat_height = target_feat_height
+ elif padding_height >= vit_patch_size:
+ assert padding_width == 0, "padding_width not 0"
+ non_pad_feat_height = target_feat_height - math.floor(
+ padding_height / vit_patch_size)
+ non_pad_feat_width = target_feat_width
+ else:
+ # small padding shorter than a vit patch
+ non_pad_feat_width = target_feat_width
+ non_pad_feat_height = target_feat_height
+
+ feat_width = non_pad_feat_width // token_compression_factor
+ feat_height = non_pad_feat_height // token_compression_factor
+ # NOTE it's possible that the non-padding feature is not divisible
+ if non_pad_feat_width % token_compression_factor != 0:
+ feat_width += 1
+ if non_pad_feat_height % token_compression_factor != 0:
+ feat_height += 1
+ num_hd_patch_tokens = feat_width * feat_height
+ num_hd_newline_tokens = feat_height
+ vit_feature_size = vit_image_size // vit_patch_size
+ num_global_image_tokens = (vit_feature_size //
+ token_compression_factor)**2
+ num_sep_tokens = 1
+ num_global_image_newline_tokens = \
+ vit_feature_size // token_compression_factor
+
+ return (num_global_image_tokens + num_sep_tokens +
+ num_hd_patch_tokens + num_hd_newline_tokens +
+ num_global_image_newline_tokens)
+
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ processor: Optional[Phi4MMProcessor] = None,
+ ) -> int:
+ hf_config = self.get_hf_config()
+ vision_config = hf_config.vision_config
+ vit_image_size = vision_config.image_size
+ vit_patch_size = vision_config.patch_size
+
+ dynamic_hd_size = self.get_dynamic_hd(processor=processor)
+
+ # we use default `token_compression_factor=2`,
+ # since it's not in HF vision config.
+ image_num_tokens = self._compute_num_image_tokens(
+ image_width,
+ image_height,
+ dynamic_hd_size=dynamic_hd_size,
+ vit_image_size=vit_image_size,
+ vit_patch_size=vit_patch_size,
+ )
+
+ return image_num_tokens
+
+ def get_image_size_with_most_features(
+ self,
+ processor: Optional[Phi4MMProcessor] = None,
+ ) -> ImageSize:
+ vit_image_size = self.get_hf_config().vision_config.image_size
+
+ max_side = vit_image_size * self.get_dynamic_hd(processor=processor)
+ return ImageSize(height=max_side, width=vit_image_size)
+
+ def get_audio_num_frames(self, audio_len: int, sr: float) -> int:
+ """
+ Compute the output size of the `extract_features` method.
+
+ Args:
+ audio_len (int): Length of the input waveform in samples.
+ sr (float): Sampling rate of the waveform, either 16000 or 8000.
+
+ Returns:
+            int: Number of time frames (T).
+ """
+
+ # Resample to 16000 or 8000 if needed
+ if sr > 16000:
+ audio_len //= sr // 16000
+ elif 8000 <= sr < 16000:
+ # We'll resample to 16K from 8K
+ audio_len *= 2
+ elif sr < 8000:
+ raise RuntimeError(f"Unsupported sample rate {sr}")
+
+ # Spectrogram parameters for 16 kHz
+ win_length = 400 # Frame length in samples
+ hop_length = 160 # Frame shift in samples
+
+ # Calculate number of frames (T)
+ num_frames = (audio_len - win_length) // hop_length + 1
+ if num_frames < 1:
+ raise ValueError("Waveform too short for given parameters.")
+
+ # Return time frames (T)
+ return num_frames
+
+ def _compute_audio_embed_size(self, audio_frames: int) -> int:
+ """
+ Compute the size of audio embeddings from the number of audio frames.
+ """
+        # The HF audio processor's `_compute_audio_embed_size` uses torch for
+        # this computation, so we re-implement it with plain Python
+        # arithmetic to avoid extra tensor conversions.
+ audio_processor = self.get_feature_extractor()
+ audio_compression_rate = audio_processor.audio_compression_rate
+ audio_downsample_rate = audio_processor.audio_downsample_rate
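+        # Illustrative example (the rates below come from the checkpoint's
+        # audio preprocessor config and are shown only as an assumption):
+        # 10 s of 16 kHz audio is 160000 samples, i.e.
+        # (160000 - 400) // 160 + 1 = 998 frames from get_audio_num_frames;
+        # with audio_compression_rate=8 and audio_downsample_rate=1 this
+        # becomes ceil(998 / 8) = 125 audio embedding tokens.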
+
+ integer = audio_frames // audio_compression_rate
+ remainder = audio_frames % audio_compression_rate
+ result = integer + int(remainder > 0)
+
+ integer = result // audio_downsample_rate
+ remainder = result % audio_downsample_rate
+ result = integer + int(remainder > 0) # qformer compression
+
+ return result
+
+
+class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
+
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_audios = mm_counts.get("audio", 0)
+ num_images = mm_counts.get("image", 0)
+
+ tokenizer = self.info.get_tokenizer()
+ image_tokens: str = tokenizer.image_token * num_images
+ audio_tokens: str = tokenizer.audio_token * num_audios
+
+ return image_tokens + audio_tokens
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> MultiModalDataDict:
+ num_audios = mm_counts.get("audio", 0)
+ num_images = mm_counts.get("image", 0)
+
+ target_width, target_height = \
+ self.info.get_image_size_with_most_features()
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images),
+ "audio":
+ self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE,
+ num_audios=num_audios),
+ }
+
+ return mm_data
+
+
+class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
+
+ def _get_data_parser(self) -> MultiModalDataParser:
+ feature_extractor = self.info.get_feature_extractor()
+ return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ if not mm_data:
+ prompt_ids = self.info.get_tokenizer().encode(prompt)
+ prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
+ return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+ audio_data = mm_data.pop("audios", [])
+ if audio_data:
+ mm_data['audio'] = audio_data
+
+ processed_outputs = super()._call_hf_processor(prompt, mm_data,
+ mm_kwargs, tok_kwargs)
+
+ if "image_pixel_values" in processed_outputs:
+ num_img_tokens = [
+ self.info.get_num_image_tokens(image_width=img_size[0],
+ image_height=img_size[1])
+ for img_size in processed_outputs["image_sizes"]
+ ]
+ processed_outputs["num_img_tokens"] = num_img_tokens
+
+ if audio_data:
+ audio_features = processed_outputs['audio_input_features']
+ sr = self.info.get_feature_extractor().sampling_rate
+ feature_sizes = [
+ self.info.get_audio_num_frames(len(audio), sr)
+ for audio in audio_data
+ ]
+ processed_outputs['audio_input_features'] = [
+ audio_features[idx, :size]
+ for idx, size in enumerate(feature_sizes)
+ ]
+
+ return processed_outputs
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(
+ image_pixel_values=MultiModalFieldConfig.batched("image"),
+ image_attention_mask=MultiModalFieldConfig.batched("image"),
+ image_sizes=MultiModalFieldConfig.batched("image"),
+ num_img_tokens=MultiModalFieldConfig.batched("image"),
+ audio_input_features=MultiModalFieldConfig.batched("audio"),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> Sequence[PromptUpdate]:
+ tokenizer = self.info.get_tokenizer()
+ image_token_id = tokenizer.vocab[tokenizer.image_token]
+ audio_token_id = tokenizer.vocab[tokenizer.audio_token]
+
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ audio_processor = self.info.get_feature_extractor()
+
+ def get_image_replacement_phi4mm(item_idx: int):
+ images = mm_items.get_items(
+ "image", (ImageEmbeddingItems, ImageProcessorItems))
+
+ if isinstance(images, ImageEmbeddingItems):
+ num_image_tokens = images.get_feature_size(item_idx)
+ else:
+ image_size = images.get_image_size(item_idx)
+ num_image_tokens = self.info.get_num_image_tokens(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ processor=hf_processor,
+ )
+
+ image_tokens = [image_token_id] * num_image_tokens
+
+ return image_tokens
+
+ def get_audio_replacement_phi4mm(item_idx: int):
+ audios = mm_items.get_items("audio", AudioProcessorItems)
+ # TODO(Isotr0py): support embedding inputs
+ audio_len = audios.get_audio_length(item_idx)
+ audio_frames = self.info.get_audio_num_frames(
+ audio_len, audio_processor.sampling_rate)
+ audio_embed_size = self.info._compute_audio_embed_size(
+ audio_frames)
+
+ audio_tokens = [audio_token_id] * audio_embed_size
+
+ return audio_tokens
+
+ return [
+ PromptReplacement(
+ modality="audio",
+ target=[audio_token_id],
+ replacement=get_audio_replacement_phi4mm,
+ ),
+ PromptReplacement(
+ modality="image",
+ target=[image_token_id],
+ replacement=get_image_replacement_phi4mm,
+ ),
+ ]
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ Phi4MMMultiModalProcessor,
+ info=Phi4MMProcessingInfo,
+ dummy_inputs=Phi4MMDummyInputsBuilder,
+)
+class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
+ """
+ Implements the Phi-4-multimodal-instruct model in vLLM.
+ """
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "qkv_proj",
+ ],
+ "gate_up_proj": [
+ "gate_up_proj",
+ ],
+ }
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ # Multimodal embedding
+ "model.embed_tokens_extend.": "",
+ # LLM backbone
+ "model.": "language_model.model.",
+ },
+ orig_to_new_substr={
+ # projection
+ ".img_projection_": ".img_projection.",
+ ".up_proj_for_speech.": ".speech_projection.up.",
+ ".up_proj_for_vision_speech.": ".vision_speech_projection.up.",
+ ".down_proj_for_speech.": ".speech_projection.down.",
+ ".down_proj_for_vision_speech.": ".vision_speech_projection.down.",
+ },
+ )
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+ if modality.startswith("image"):
+ return "<|image|>"
+ if modality.startswith("audio"):
+ return "<|audio|>"
+
+ raise ValueError("Only image or audio modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+ self.config = config
+ self.multimodal_config = multimodal_config
+
+        # TODO: Optionally initialize these to support input embeddings.
+ self.image_embed = Phi4MMImageEmbedding(
+ config,
+ # prefix=maybe_prefix(prefix, "image_embed"),
+ )
+ self.audio_embed = Phi4MMAudioEmbedding(
+ config,
+ # prefix=maybe_prefix(prefix, "audio_embed"),
+ )
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ architectures=["Phi3ForCausalLM"],
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ def _parse_and_validate_audio_input(
+ self, **kwargs: object) -> Optional[Phi4MMAudioInputs]:
+ """
+ Parse and validate the audio input to the model. This handles both
+ audio features and audio embeddings, but only the former is used for
+ now.
+
+ Args:
+ kwargs (object): Keyword arguments.
+
+ Returns:
+ Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs.
+ """
+ audio_features = kwargs.pop("audio_input_features", None)
+ audio_embeds = kwargs.pop("audio_embeds", None)
+
+ if audio_features is None and audio_embeds is None:
+ return None
+
+ if audio_features is not None:
+ if not isinstance(audio_features, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of audio features. "
+ f"Got type: {type(audio_features)}")
+
+ return Phi4MMAudioFeatureInputs(type="audio_features",
+ data=flatten_bn(audio_features))
+
+ if audio_embeds is not None:
+ if not isinstance(audio_embeds, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of audio embeds. "
+ f"Got type: {type(audio_embeds)}")
+
+ return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
+ data=audio_embeds)
+
+ raise AssertionError("This line should be unreachable.")
+
+ def _process_audio_input(self, audio_input: Phi4MMAudioInputs,
+ audio_projection_mode: str) -> NestedTensors:
+ """
+        Create the audio embeddings from the audio input, where the audio
+        input consists of the audio features produced by the multimodal
+        processor.
+
+ Args:
+ audio_input (Phi4MMAudioInputs): Audio input.
+
+ Returns:
+ NestedTensors: Audio embeddings
+ """
+ if audio_input["type"] == "audio_embeds":
+ return audio_input["data"]
+
+ audio_features = audio_input["data"]
+        # Each element of audio_features holds the mel features of a single
+        # audio item (the batch and per-prompt audio dims were flattened by
+        # `flatten_bn` during input parsing).
+
+ dtype = next(self.audio_embed.parameters()).dtype
+ audio_embeds = [
+ self.audio_embed(
+ features.unsqueeze(0).to(dtype),
+ audio_projection_mode=audio_projection_mode,
+ ) for features in audio_features
+ ]
+ return audio_embeds
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]:
+ image_pixel_values: NestedTensors = kwargs.get("image_pixel_values")
+ if image_pixel_values is None:
+ return None
+
+ image_sizes = kwargs.get("image_sizes")
+ image_attention_mask = kwargs.get("image_attention_mask")
+ num_img_tokens = kwargs.get("num_img_tokens")
+ assert image_sizes is not None and image_attention_mask is not None\
+ and num_img_tokens is not None, "Missing image inputs"
+
+ if is_list_of(image_pixel_values, torch.Tensor):
+ assert all(p.dim() == 5
+ for p in image_pixel_values), "Incorrect image inputs"
+ # list len is batch_size.
+ # each tensor has dimension: num_img_per_example, num_hd_patches,
+ # channels, height, width.
+ # need to pad along num_hd_patches.
+ # mask size: num_img_per_prompt, num_hd_patches, feat_h, feat_w.
+ image_pixel_values = cat_with_pad(image_pixel_values, dim=0)
+ elif isinstance(image_pixel_values, torch.Tensor):
+ # dimension: batch_size, num_img_per_example, num_hd_patches,
+ # channels, height, width.
+ # we flatten first 2 dims to make it a single large batch for
+ # SigLIP Encoder.
+ assert image_pixel_values.dim() == 6, "Incorrect image inputs"
+ image_pixel_values = image_pixel_values.flatten(0, 1)
+ else:
+ raise ValueError("Incorrect image_pixel_values inputs")
+
+ if isinstance(image_attention_mask, list):
+ image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
+ elif isinstance(image_attention_mask, torch.Tensor):
+ image_attention_mask = image_attention_mask.flatten(0, 1)
+ else:
+ raise ValueError("Incorrect image_attention_mask inputs")
+
+ if isinstance(image_sizes, list):
+ image_sizes = torch.cat(image_sizes, dim=0)
+ elif isinstance(image_sizes, torch.Tensor):
+ image_sizes = image_sizes.flatten(0, 1)
+ else:
+ raise ValueError("Incorrect image_attention_mask inputs")
+
+ if isinstance(num_img_tokens, list):
+ num_img_tokens = [
+ n for num_tensor in num_img_tokens
+ for n in num_tensor.tolist()
+ ]
+ elif isinstance(num_img_tokens, torch.Tensor):
+ num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
+ else:
+ raise ValueError("Incorrect image_attention_mask inputs")
+
+ return Phi4MMImagePixelInputs(
+ type="pixel_values",
+ data=image_pixel_values,
+ image_sizes=image_sizes,
+ image_attention_mask=image_attention_mask,
+ num_img_tokens=num_img_tokens,
+ )
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ modalities = {}
+
+ # Preserve the order of modalities (if there are multiple of them)
+ # based on the order of the kwargs.
+ for input_key in kwargs:
+ if input_key in ("image_pixel_values",
+ "image_embeds") and "images" not in modalities:
+ modalities["images"] = self._parse_and_validate_image_input(
+ **kwargs)
+ if input_key in ("audio_input_features",
+ "audio_embeds") and "audios" not in modalities:
+ modalities["audios"] = self._parse_and_validate_audio_input(
+ **kwargs)
+
+ return modalities
+
+ def _process_image_input(
+ self, image_input: Phi4MMImagePixelInputs) -> list[torch.Tensor]:
+ if image_input["type"] == "image_embeds":
+ image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+ else:
+ dtype = next(self.image_embed.parameters()).dtype
+ pixel_values = image_input['data'].to(dtype)
+ image_sizes = image_input['image_sizes']
+ image_attention_mask = image_input['image_attention_mask']
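+ # The image embedding module runs the HD patches through the vision
+ # encoder and projects the result into the LM hidden size, returning
+ # one embedding tensor per image.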
+ image_embeds = self.image_embed(pixel_values, image_sizes,
+ image_attention_mask)
+ return image_embeds
+
+ def get_multimodal_embeddings(
+ self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+
+ modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not modalities:
+ return None
+
+ # The resulting multimodal_embeddings is a tuple of tensors, with each
+ # tensor corresponding to a multimodal data item (image or audio).
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ audio_projection_mode = 'speech'
+ for modality in modalities:
+ # Make sure to process images first: when images are present, the
+ # audio features go through the vision-speech projector rather than
+ # the speech-only one.
+ if modality == "images":
+ audio_projection_mode = "vision"
+ image_input = modalities["images"]
+ vision_embeddings = self._process_image_input(image_input)
+ multimodal_embeddings += tuple(vision_embeddings)
+ if modality == "audios":
+ audio_input = modalities["audios"]
+ audio_embeddings = self._process_audio_input(
+ audio_input, audio_projection_mode=audio_projection_mode)
+ multimodal_embeddings += tuple(audio_embeddings)
+
+ return multimodal_embeddings
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+ if multimodal_embeddings is not None:
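+ # Scatter the multimodal embeddings into the positions of the
+ # image/audio placeholder tokens in the input sequence.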
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids, inputs_embeds, multimodal_embeddings,
+ [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID])
+ return inputs_embeds
+
+ def get_input_embeddings_v0(
+ self,
+ input_ids: torch.Tensor,
+ image_input: Optional[Phi4MMImagePixelInputs] = None,
+ audio_input: Optional[Phi4MMAudioFeatureInputs] = None,
+ ) -> torch.Tensor:
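+ """
+ V0 compatibility path: merge the already-parsed image/audio inputs
+ into the text embeddings in a single call, switching the audio
+ projection mode to "vision" when images are present.
+ """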
+ audio_projection_mode = 'speech'
+ inputs_embeds = self.get_input_embeddings(input_ids)
+ if image_input is not None:
+ image_embeds = self._process_image_input(image_input)
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ image_embeds,
+ placeholder_token_id=_IMAGE_PLACEHOLDER_TOKEN_ID,
+ )
+ audio_projection_mode = 'vision'
+
+ if audio_input is not None:
+ audio_embeds = self._process_audio_input(
+ audio_input, audio_projection_mode=audio_projection_mode)
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ audio_embeds,
+ placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN_ID,
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> torch.Tensor:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ # NOTE: In v1, inputs_embeds is always generated by the model runner
+ # via `get_multimodal_embeddings` and `get_input_embeddings`; this
+ # branch exists only for v0 compatibility.
+ elif inputs_embeds is None:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ audio_input = self._parse_and_validate_audio_input(**kwargs)
+
+ if image_input is None and audio_input is None:
+ inputs_embeds = None
+ else:
+ inputs_embeds = self.get_input_embeddings_v0(
+ input_ids,
+ image_input=image_input,
+ audio_input=audio_input)
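+ # The language model consumes the merged embeddings directly, so
+ # input_ids is no longer needed.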
+ input_ids = None
+
+ hidden_states = self.language_model(
+ input_ids,
+ positions,
+ intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="language_model.",
+ connector=[
+ "img_projection", "vision_speech_projection",
+ "speech_projection"
+ ],
+ tower_model=["image_embed", "audio_embed"],
+ )
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index fafb6a704383..f33a69235d46 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -220,6 +220,8 @@
"Ovis": ("ovis", "Ovis"),
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
+ "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
+ "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"), # noqa: E501
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501
"QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"), # noqa: E501
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
@@ -228,7 +230,6 @@
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
- "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 25dd71d877fb..24ddd35abea6 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -295,7 +295,7 @@ def cached_tokenizer_from_config(
return cached_get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
- tokenizer_revision=model_config.tokenizer_revision,
+ revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
**kwargs,
)