From 74ce1e811163679e50d8e5cc99768344d1da8baf Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 6 Mar 2025 08:31:22 +0000 Subject: [PATCH 1/3] Done Signed-off-by: Jee Jee Li --- examples/offline_inference/vision_language.py | 39 +++++++++++++++++++ vllm/model_executor/models/phi4mm.py | 18 +++++---- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 270c0f59cc58..9b424da65e8d 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -8,11 +8,13 @@ """ import random +from huggingface_hub import snapshot_download from transformers import AutoTokenizer from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset +from vllm.lora.request import LoRARequest from vllm.utils import FlexibleArgumentParser # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on @@ -519,6 +521,42 @@ def run_phi3v(questions: list[str], modality: str): return llm, prompts, stop_token_ids +# Phi-4-multimodal-instruct +def run_phi4mm(questions: list[str], modality: str): + """ + Phi-4-multimodal-instruct supports both image and audio inputs. Here, we + only show how to process image inputs, as the processing logic for audio + inputs follows a similar method. + """ + + assert modality == "image" + # Since the vision-lora and speech-lora co-exist with the base model, + # we have to manually specify the path of the lora weights. + model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct") + vision_lora_path = model_path + "/vision-lora" + prompts = [ + f"<|user|><|image_1|>{question}<|end|><|assistant|>" + for question in questions + ] + llm = LLM( + model=model_path, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=2, + enable_lora=True, + max_lora_rank=320, + lora_extra_vocab_size=0, + ) + lora_request = LoRARequest("vision", 1, vision_lora_path) + # To maintain code compatibility in this script, we add LoRA here. + llm.llm_engine.add_lora(lora_request=lora_request) + # You can also add LoRA using: + # llm.generate(prompts, lora_request=lora_request,...) + + stop_token_ids = None + return llm, prompts, stop_token_ids + + # Pixtral HF-format def run_pixtral_hf(questions: list[str], modality: str): assert modality == "image" @@ -644,6 +682,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str): "paligemma": run_paligemma, "paligemma2": run_paligemma2, "phi3_v": run_phi3v, + "phi4_mm": run_phi4mm, "pixtral_hf": run_pixtral_hf, "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 27ae9bcca2e4..6f5ea5af6c0a 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -25,6 +25,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalInputs, NestedTensors @@ -1421,7 +1422,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): """ Implements the Phi-4-multimodal-instruct model in VLLM. """ - # LoRA specific attributes packed_modules_mapping = { "qkv_proj": [ "qkv_proj", @@ -1430,12 +1430,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): "gate_up_proj", ], } - supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj" - ] - # Phi4MMForCausalLM does not apply LoRA to the embedding layer. - embedding_modules = {} - embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -1801,3 +1795,13 @@ def sample( ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="model.", + connector=["audio_projection_for_vision", "audio_projection"], + tower_model=["vision_encoder", "embed_tokens_extend"], + ) \ No newline at end of file From 615110c37d6667e871f86d1415afafd9c429b68c Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 6 Mar 2025 15:50:55 +0000 Subject: [PATCH 2/3] Init add audio example Signed-off-by: Jee Jee Li --- examples/offline_inference/audio_language.py | 38 +++++++++++++++++++ examples/offline_inference/vision_language.py | 9 ++--- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 1ceec026b319..d0fa56d4bbf1 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -6,10 +6,14 @@ For most models, the prompt format should follow corresponding examples on HuggingFace model repository. """ +import os + +from huggingface_hub import snapshot_download from transformers import AutoTokenizer from vllm import LLM, SamplingParams from vllm.assets.audio import AudioAsset +from vllm.lora.request import LoRARequest from vllm.utils import FlexibleArgumentParser audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] @@ -51,6 +55,39 @@ def run_minicpmo(question: str, audio_count: int): return llm, prompt, stop_token_ids +# Phi-4-multimodal-instruct +def run_phi4mm(questions: str, audio_count: int): + """ + Phi-4-multimodal-instruct supports both image and audio inputs. Here, we + show how to process audio inputs. + """ + # Since the vision-lora and speech-lora co-exist with the base model, + # we have to manually specify the path of the lora weights. + model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct") + speech_lora_path = os.path.join(model_path, "speech-lora") + prompts = [ + f"<|user|><|audio_1|>{question}<|end|><|assistant|>" + for question in questions + ] + llm = LLM( + model=model_path, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=2, + enable_lora=True, + max_lora_rank=320, + lora_extra_vocab_size=0, + ) + lora_request = LoRARequest("speech", 1, speech_lora_path) + # To maintain code compatibility in this script, we add LoRA here. + llm.llm_engine.add_lora(lora_request=lora_request) + # You can also add LoRA using: + # llm.generate(prompts, lora_request=lora_request,...) + + stop_token_ids = None + return llm, prompts, stop_token_ids + + # Qwen2-Audio def run_qwen2_audio(question: str, audio_count: int): model_name = "Qwen/Qwen2-Audio-7B-Instruct" @@ -113,6 +150,7 @@ def run_whisper(question: str, audio_count: int): model_example_map = { "minicpmo": run_minicpmo, + "phi4_mm": run_phi4mm, "qwen2_audio": run_qwen2_audio, "ultravox": run_ultravox, "whisper": run_whisper, diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 9b424da65e8d..716c31b96ed1 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -6,6 +6,7 @@ For most models, the prompt format should follow corresponding examples on HuggingFace model repository. """ +import os import random from huggingface_hub import snapshot_download @@ -525,15 +526,13 @@ def run_phi3v(questions: list[str], modality: str): def run_phi4mm(questions: list[str], modality: str): """ Phi-4-multimodal-instruct supports both image and audio inputs. Here, we - only show how to process image inputs, as the processing logic for audio - inputs follows a similar method. + show how to process image inputs. """ - assert modality == "image" + model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct") # Since the vision-lora and speech-lora co-exist with the base model, # we have to manually specify the path of the lora weights. - model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct") - vision_lora_path = model_path + "/vision-lora" + vision_lora_path = os.path.join(model_path, "vision-lora") prompts = [ f"<|user|><|image_1|>{question}<|end|><|assistant|>" for question in questions From 69df8b212aa5af9a182422f5064d00b5127e0a4c Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 7 Mar 2025 10:14:58 +0000 Subject: [PATCH 3/3] Done Signed-off-by: Jee Jee Li --- examples/offline_inference/audio_language.py | 10 ++--- .../vision_language_multi_image.py | 44 +++++++++++++++++++ 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index d0fa56d4bbf1..4aa233211b0b 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -61,14 +61,14 @@ def run_phi4mm(questions: str, audio_count: int): Phi-4-multimodal-instruct supports both image and audio inputs. Here, we show how to process audio inputs. """ + model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct") # Since the vision-lora and speech-lora co-exist with the base model, # we have to manually specify the path of the lora weights. - model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct") speech_lora_path = os.path.join(model_path, "speech-lora") - prompts = [ - f"<|user|><|audio_1|>{question}<|end|><|assistant|>" - for question in questions - ] + placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)]) + + prompts = f"<|user|>{placeholders}{questions}<|end|><|assistant|>" + llm = LLM( model=model_path, trust_remote_code=True, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index b1aec33cff46..6fdd4383c1a1 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -4,13 +4,16 @@ multi-image input on vision language models for text generation, using the chat template defined by the model. """ +import os from argparse import Namespace from typing import NamedTuple, Optional +from huggingface_hub import snapshot_download from PIL.Image import Image from transformers import AutoProcessor, AutoTokenizer from vllm import LLM, SamplingParams +from vllm.lora.request import LoRARequest from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser @@ -294,6 +297,46 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData: + """ + Phi-4-multimodal-instruct supports both image and audio inputs. Here, we + show how to process multi images inputs. + """ + + model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct") + # Since the vision-lora and speech-lora co-exist with the base model, + # we have to manually specify the path of the lora weights. + vision_lora_path = os.path.join(model_path, "vision-lora") + llm = LLM( + model=model_path, + trust_remote_code=True, + max_model_len=10000, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + enable_lora=True, + max_lora_rank=320, + lora_extra_vocab_size=0, + ) + lora_request = LoRARequest("vision", 1, vision_lora_path) + # To maintain code compatibility in this script, we add LoRA here. + llm.llm_engine.add_lora(lora_request=lora_request) + # You can also add LoRA using: + # llm.generate(prompts, lora_request=lora_request,...) + + placeholders = "".join(f"<|image_{i}|>" + for i, _ in enumerate(image_urls, start=1)) + prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>" + stop_token_ids = None + + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None, + ) + + def load_qwen_vl_chat(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "Qwen/Qwen-VL-Chat" @@ -459,6 +502,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData: "mllama": load_mllama, "NVLM_D": load_nvlm_d, "phi3_v": load_phi3v, + "phi4_mm": load_phi4mm, "pixtral_hf": load_pixtral_hf, "qwen_vl_chat": load_qwen_vl_chat, "qwen2_vl": load_qwen2_vl,