2 changes: 2 additions & 0 deletions tests/models/multimodal/processing/test_common.py
@@ -174,6 +174,8 @@ def _test_processing_correctness(
    "Qwen/Qwen2-Audio-7B-Instruct",
    "fixie-ai/ultravox-v0_4",
    "openai/whisper-large-v3",
+    "google/paligemma-3b-mix-224",
+    "google/paligemma2-3b-ft-docci-448",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
236 changes: 159 additions & 77 deletions vllm/model_executor/models/paligemma.py
@@ -5,22 +5,25 @@

import torch
from torch import nn
-from transformers import PaliGemmaConfig
+from transformers import BatchFeature, PaliGemmaConfig

from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputs, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from .interfaces import SupportsMultiModal, SupportsPP
-from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
-                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
+from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)

@@ -46,95 +49,174 @@ class PaliGemmaImageEmbeddingInputs(TypedDict):
PaliGemmaImageEmbeddingInputs]


-def get_max_paligemma_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-    vision_config = hf_config.vision_config
-
-    return get_max_siglip_image_tokens(vision_config)
-
-
-def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
-                             mm_counts: Mapping[str, int]):
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-    vision_config = hf_config.vision_config
-    num_images = mm_counts["image"]
-
-    seq_data, ranges = dummy_seq_data_for_siglip(
-        vision_config,
-        seq_len,
-        num_images,
-        image_token_id=hf_config.image_token_index,
-    )
-
-    mm_data = dummy_image_for_siglip(vision_config, num_images)
-    return DummyData(seq_data, mm_data, ranges)
-
-
-def input_processor_for_paligemma(ctx: InputContext,
-                                  inputs: DecoderOnlyInputs):
-    """
-    The correct prompt format needs to be:
-    '<image>' * image_feature_size + '<bos>' + prompt + '\n'
-
-    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
-    """ # noqa
-
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-
-    tokenizer = cached_tokenizer_from_config(model_config)
-    image_feature_size = hf_config.text_config.num_image_tokens
-    image_token_str = tokenizer.decode(hf_config.image_token_index)
-    bos_token = tokenizer.decode(hf_config.bos_token_id)
-    image_token_str_pad = image_token_str * image_feature_size
-    image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
-
-    orig_prompt = inputs.get("prompt")
-    orig_prompt_ids = inputs.get("prompt_token_ids")
-
-    if orig_prompt is not None and image_token_str in orig_prompt:
-        logger.warning(
-            "The image token '%s' was detected in the prompt and "
-            "will be removed. Please follow the proper prompt format"
-            " documented on HuggingFace.", image_token_str)
-        orig_prompt = orig_prompt.replace(image_token_str, "")
-        orig_prompt_ids.remove(hf_config.image_token_index)
-
-    new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
-
-    # The PaliGemma 2 tokenizer does not include a starting BOS token
-    if orig_prompt_ids[0] != hf_config.bos_token_id:
-        orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids
-
-    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  #newline
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+class PaliGemmaMultiModalProjector(nn.Module):
+
+    def __init__(self, vision_hidden_size: int, projection_dim: int):
+        super().__init__()
+
+        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
+
+    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.linear(image_features)
+        return hidden_states
+
+
+class PaliGemmaProcessingInfo(BaseProcessingInfo):
+
+    def get_hf_config(self):
+        return self.ctx.get_hf_config(PaliGemmaConfig)
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}
+
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        return {"image": self.get_num_image_tokens()}
+
+    def get_num_image_tokens(self) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        return get_max_siglip_image_tokens(vision_config)
+
+
+class PaliGemmaDummyInputsBuilder(
+        BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        hf_config = self.info.get_hf_config()
+        vision_config = hf_config.vision_config
+        max_image_size = vision_config.image_size
+
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=max_image_size,
+                                   height=max_image_size,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="",
+            mm_data=mm_data,
+        )
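For context, a minimal sketch (not from this PR) of the token layout that the removed docstring above describes; the ids and the feature size below are placeholders, not the real PaliGemma values:

```python
# Illustrative only: expected PaliGemma prompt layout after processing.
IMAGE_TOKEN_ID = 9      # placeholder for <image>
BOS_TOKEN_ID = 2        # placeholder for <bos>
NEWLINE_TOKEN_ID = 108  # "\n" for the Gemma tokenizer, per the code above

image_feature_size = 3     # real checkpoints use hundreds of image tokens
prompt_ids = [1001, 1002]  # placeholder ids for the tokenized user text

expected = ([IMAGE_TOKEN_ID] * image_feature_size + [BOS_TOKEN_ID] +
            prompt_ids + [NEWLINE_TOKEN_ID])
print(expected)  # [9, 9, 9, 2, 1001, 1002, 108]
```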

-class PaliGemmaMultiModalProjector(nn.Module):
-
-    def __init__(self, vision_hidden_size: int, projection_dim: int):
-        super().__init__()
-
-        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
-
-    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.linear(image_features)
-        return hidden_states
+class PaliGemmaMultiModalProcessor(
+        BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        if not mm_data:
+            # HF processor always adds placeholders even when there's no image
+            tokenizer = self.info.get_tokenizer()
+            prompt_ids = tokenizer.encode(prompt)
+            # Paligemma2 is NOT adding <bos> token at the beginning
+            # Adding <bos> (tokenizer.bos_token_id) for prompt replacement
+            if len(prompt_ids) == 0:
+                prompt_ids = [tokenizer.bos_token_id]
+            elif prompt_ids[0] != tokenizer.bos_token_id:
+                prompt_ids = [tokenizer.bos_token_id] + prompt_ids
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        return super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(pixel_values=MultiModalFieldConfig.batched("image"))
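A tiny standalone sketch (not vLLM code) of the `<bos>`-prepend rule used in the text-only branch of `_call_hf_processor` above; the id is a placeholder:

```python
BOS = 2  # placeholder for tokenizer.bos_token_id

def add_leading_bos(prompt_ids: list[int]) -> list[int]:
    # Mirror the branch above: an empty prompt becomes [<bos>], and a prompt
    # that does not already start with <bos> gets one prepended.
    if len(prompt_ids) == 0:
        return [BOS]
    if prompt_ids[0] != BOS:
        return [BOS] + prompt_ids
    return prompt_ids

print(add_leading_bos([]))         # [2]
print(add_leading_bos([5, 6]))     # [2, 5, 6]
print(add_leading_bos([2, 5, 6]))  # [2, 5, 6]
```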
+    def _apply_hf_processor_tokens_only(
+        self,
+        prompt_tokens: list[int],
+    ) -> list[int]:
+        tokenizer = self.info.get_tokenizer()
+        # Adding <bos> (tokenizer.bos_token_id) to match with
+        # self._call_hf_processor
+        if len(prompt_tokens) == 0:
+            tokens_with_bos = [tokenizer.bos_token_id]
+        elif prompt_tokens[0] != tokenizer.bos_token_id:
+            tokens_with_bos = [tokenizer.bos_token_id] + prompt_tokens
+        else:
+            tokens_with_bos = prompt_tokens
+        return tokens_with_bos
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        hf_config = self.info.get_hf_config()
+        image_token_id = hf_config.image_token_index
+
+        tokenizer = self.info.get_tokenizer()
+        num_image_tokens = self.info.get_num_image_tokens()
+        image_tokens = [image_token_id] * num_image_tokens
+
+        bos_token_id = tokenizer.bos_token_id
+        assert isinstance(bos_token_id, int)
+
+        # Adding <bos> token at the beginning based on add_bos_token variable.
+        # Always adding <bos> token at the end of the images tokens
+        # and before the text prompt according to Paligemma's format
+        if tokenizer.add_bos_token:
+            replacement_tokens = [bos_token_id] + image_tokens + [bos_token_id]
+        else:
+            replacement_tokens = image_tokens + [bos_token_id]
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=[bos_token_id],
+                replacement=PromptReplacementDetails(
+                    full=replacement_tokens,
+                    features=image_tokens,
+                ),
+            )
+        ]
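Roughly, the replacement above rewrites a tokenized prompt as follows (illustrative sketch only, assuming `add_bos_token=True`, a toy feature size, and placeholder ids rather than the real vocabulary):

```python
BOS, IMG = 2, 9        # placeholder ids for <bos> and <image>
num_image_tokens = 3   # toy value; the real count comes from the SigLIP config

prompt = [BOS, 101, 102]                          # <bos> + user text
full = [BOS] + [IMG] * num_image_tokens + [BOS]   # replacement_tokens

# The single <bos> target is expanded in place, so the prompt becomes:
replaced = full + prompt[1:]
print(replaced)  # [2, 9, 9, 9, 2, 101, 102]
```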

+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputs:
+        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
+
+        tokenizer = self.info.get_tokenizer()
+        newline_prompt = "\n"
+        newline_token_id = tokenizer.encode(newline_prompt)[-1]  # 108
+        # Force to add newline at the end of prompt for paligemma's format
+        if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
+            prompt_token_ids.append(newline_token_id)
+            mm_inputs["prompt_token_ids"] = prompt_token_ids
+            mm_inputs["prompt"] += newline_prompt
+        return mm_inputs
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)


+@MULTIMODAL_REGISTRY.register_processor(
+    PaliGemmaMultiModalProcessor,
+    info=PaliGemmaProcessingInfo,
+    dummy_inputs=PaliGemmaDummyInputsBuilder)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
packed_modules_mapping = {
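For anyone exercising this locally, a rough end-to-end usage sketch (not part of this PR); the model name, image path, and sampling settings are illustrative:

```python
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="google/paligemma-3b-mix-224")
image = Image.open("demo.jpg")  # any local image

outputs = llm.generate(
    {
        # Plain text only: the processor inserts the <image> placeholders,
        # the <bos>, and the trailing newline required by PaliGemma's format.
        "prompt": "caption en",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=32),
)
print(outputs[0].outputs[0].text)
```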