38 changes: 38 additions & 0 deletions examples/offline_inference/audio_language.py
@@ -6,10 +6,14 @@
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import os

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest
from vllm.utils import FlexibleArgumentParser

audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
@@ -51,6 +55,39 @@ def run_minicpmo(question: str, audio_count: int):
return llm, prompt, stop_token_ids


# Phi-4-multimodal-instruct
def run_phi4mm(questions: str, audio_count: int):
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
show how to process audio inputs.
"""
model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
speech_lora_path = os.path.join(model_path, "speech-lora")
placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])

prompts = f"<|user|>{placeholders}{questions}<|end|><|assistant|>"

llm = LLM(
model=model_path,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
enable_lora=True,
max_lora_rank=320,
lora_extra_vocab_size=0,
)
lora_request = LoRARequest("speech", 1, speech_lora_path)
# To maintain code compatibility in this script, we add LoRA here.
llm.llm_engine.add_lora(lora_request=lora_request)
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)

stop_token_ids = None
return llm, prompts, stop_token_ids


# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int):
model_name = "Qwen/Qwen2-Audio-7B-Instruct"
@@ -113,6 +150,7 @@ def run_whisper(question: str, audio_count: int):

model_example_map = {
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,
"qwen2_audio": run_qwen2_audio,
"ultravox": run_ultravox,
"whisper": run_whisper,
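For context, a minimal sketch of how the values returned by run_phi4mm are consumed, assuming the snippet sits alongside run_phi4mm and the module-level audio_assets list in examples/offline_inference/audio_language.py; the question text and sampling settings are illustrative, not part of this diff.

audio_count = 1
llm, prompt, stop_token_ids = run_phi4mm("What is recited in the audio?",
                                         audio_count)
sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)
# The prompt and the decoded audio clips travel together in one request;
# the speech LoRA was already registered inside run_phi4mm via add_lora().
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {
            "audio": [
                asset.audio_and_sample_rate
                for asset in audio_assets[:audio_count]
            ],
        },
    },
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)

As the in-code comment notes, the LoRA adapter could instead be supplied per call via llm.generate(..., lora_request=lora_request).
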
38 changes: 38 additions & 0 deletions examples/offline_inference/vision_language.py
@@ -6,13 +6,16 @@
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import os
import random

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.lora.request import LoRARequest
from vllm.utils import FlexibleArgumentParser

# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
@@ -519,6 +522,40 @@ def run_phi3v(questions: list[str], modality: str):
return llm, prompts, stop_token_ids


# Phi-4-multimodal-instruct
def run_phi4mm(questions: list[str], modality: str):
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
show how to process image inputs.
"""
assert modality == "image"
model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
vision_lora_path = os.path.join(model_path, "vision-lora")
prompts = [
f"<|user|><|image_1|>{question}<|end|><|assistant|>"
for question in questions
]
llm = LLM(
model=model_path,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
enable_lora=True,
max_lora_rank=320,
lora_extra_vocab_size=0,
)
lora_request = LoRARequest("vision", 1, vision_lora_path)
# To maintain code compatibility in this script, we add LoRA here.
llm.llm_engine.add_lora(lora_request=lora_request)
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)

stop_token_ids = None
return llm, prompts, stop_token_ids


# Pixtral HF-format
def run_pixtral_hf(questions: list[str], modality: str):
assert modality == "image"
@@ -644,6 +681,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str):
"paligemma": run_paligemma,
"paligemma2": run_paligemma2,
"phi3_v": run_phi3v,
"phi4_mm": run_phi4mm,
"pixtral_hf": run_pixtral_hf,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
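For context, a minimal sketch of how the run_phi4mm added to examples/offline_inference/vision_language.py would be driven; the image asset, question, and sampling settings are illustrative, not part of this diff.

questions = ["What is the content of this image?"]
llm, prompts, stop_token_ids = run_phi4mm(questions, modality="image")
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)
# One request per prompt, each carrying the same PIL image; the vision LoRA
# was already registered inside run_phi4mm via add_lora().
inputs = [{
    "prompt": prompt,
    "multi_modal_data": {"image": image},
} for prompt in prompts]
outputs = llm.generate(inputs, sampling_params=sampling_params)
for o in outputs:
    print(o.outputs[0].text)
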
44 changes: 44 additions & 0 deletions examples/offline_inference/vision_language_multi_image.py
@@ -4,13 +4,16 @@
multi-image input on vision language models for text generation,
using the chat template defined by the model.
"""
import os
from argparse import Namespace
from typing import NamedTuple, Optional

from huggingface_hub import snapshot_download
from PIL.Image import Image
from transformers import AutoProcessor, AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser

@@ -294,6 +297,46 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
)


def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process multi-image inputs.
"""

model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
vision_lora_path = os.path.join(model_path, "vision-lora")
llm = LLM(
model=model_path,
trust_remote_code=True,
max_model_len=10000,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
enable_lora=True,
max_lora_rank=320,
lora_extra_vocab_size=0,
)
lora_request = LoRARequest("vision", 1, vision_lora_path)
# To maintain code compatibility in this script, we add LoRA here.
llm.llm_engine.add_lora(lora_request=lora_request)
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)

placeholders = "".join(f"<|image_{i}|>"
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
stop_token_ids = None

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)


def load_qwen_vl_chat(question: str,
image_urls: list[str]) -> ModelRequestData:
model_name = "Qwen/Qwen-VL-Chat"
@@ -459,6 +502,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,
"phi3_v": load_phi3v,
"phi4_mm": load_phi4mm,
"pixtral_hf": load_pixtral_hf,
"qwen_vl_chat": load_qwen_vl_chat,
"qwen2_vl": load_qwen2_vl,
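For context, a minimal sketch of how the ModelRequestData returned by load_phi4mm is consumed, assuming it sits next to that function in examples/offline_inference/vision_language_multi_image.py; the two URLs and the sampling settings are placeholders, not part of this diff.

image_urls = [
    "https://example.com/first.jpg",   # placeholder: substitute real image URLs
    "https://example.com/second.jpg",  # placeholder: substitute real image URLs
]
req = load_phi4mm("What are the contents of these images?", image_urls)
sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=128,
                                 stop_token_ids=req.stop_token_ids)
# All fetched images are attached to a single request together with the
# multi-image prompt; the vision LoRA was registered inside load_phi4mm.
outputs = req.llm.generate(
    {
        "prompt": req.prompt,
        "multi_modal_data": {"image": req.image_data},
    },
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
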
18 changes: 11 additions & 7 deletions vllm/model_executor/models/phi4mm.py
@@ -25,6 +25,7 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalInputs, NestedTensors
@@ -1421,7 +1422,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
"""
Implements the Phi-4-multimodal-instruct model in VLLM.
"""
# LoRA specific attributes
packed_modules_mapping = {
"qkv_proj": [
"qkv_proj",
@@ -1430,12 +1430,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
"gate_up_proj",
],
}
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj"
]
# Phi4MMForCausalLM does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -1801,3 +1795,13 @@ def sample(
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="model.",
connector=["audio_projection_for_vision", "audio_projection"],
tower_model=["vision_encoder", "embed_tokens_extend"],
)
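
To make the new get_mm_mapping hook concrete, here is an illustrative sketch (not part of this diff) of how the returned MultiModelKeys prefixes can be used to tell language-model weights apart from connector and tower weights; the classify_param helper is hypothetical.

from vllm.model_executor.models.module_mapping import MultiModelKeys

mm_keys = MultiModelKeys.from_string_field(
    language_model="model.",
    connector=["audio_projection_for_vision", "audio_projection"],
    tower_model=["vision_encoder", "embed_tokens_extend"],
)

def classify_param(name: str, keys: MultiModelKeys) -> str:
    """Hypothetical helper: map a parameter name to its component group."""
    if any(name.startswith(p) for p in keys.tower_model):
        return "tower_model"
    if any(name.startswith(p) for p in keys.connector):
        return "connector"
    if any(name.startswith(p) for p in keys.language_model):
        return "language_model"
    return "other"

print(classify_param("model.layers.0.self_attn.qkv_proj.weight", mm_keys))
# -> language_model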