41 changes: 36 additions & 5 deletions examples/offline_inference_vision_language_multi_image.py
@@ -23,7 +23,7 @@
class ModelRequestData(NamedTuple):
llm: LLM
prompt: str
-    stop_token_ids: Optional[List[str]]
+    stop_token_ids: Optional[List[int]]
image_data: List[Image]
chat_template: Optional[str]

@@ -44,12 +44,14 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
-        chat_template=None)
+        chat_template=None,
+    )


def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
@@ -166,7 +168,8 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)},
)

prompt = f"<|image|><|image|><|begin_of_text|>{question}"
placeholders = "<|image|>" * len(image_urls)
prompt = f"{placeholders}<|begin_of_text|>{question}"
return ModelRequestData(
llm=llm,
prompt=prompt,
@@ -209,6 +212,31 @@ def load_nvlm_d(question: str, image_urls: List[str]):
)


def load_pixtral_hf(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "mistral-community/pixtral-12b"
Member: If you want to use a chat template, you could use this model where I added it: https://huggingface.co/mgoin/pixtral-12b

Member Author: I think the prompt format is straightforward enough that we don't need a chat template for this.

# Adjust this as necessary to fit in GPU
llm = LLM(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
tensor_parallel_size=2,
limit_mm_per_prompt={"image": len(image_urls)},
)

placeholders = "[IMG]" * len(image_urls)
prompt = f"<s>[INST]{question}\n{placeholders}[/INST]"
stop_token_ids = None

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)


def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
# num_crops is an override kwarg to the multimodal image processor;
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
@@ -244,7 +272,8 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
)


-def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
+def load_qwen_vl_chat(question: str,
+                      image_urls: List[str]) -> ModelRequestData:
model_name = "Qwen/Qwen-VL-Chat"
llm = LLM(
model=model_name,
@@ -274,6 +303,7 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:

stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

return ModelRequestData(
llm=llm,
prompt=prompt,
@@ -348,7 +378,8 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,
"phi3_v": load_phi3v,
"qwen_vl_chat": load_qwenvl_chat,
"pixtral_hf": load_pixtral_hf,
"qwen_vl_chat": load_qwen_vl_chat,
"qwen2_vl": load_qwen2_vl,
}

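For context on how the new load_pixtral_hf entry above is exercised: every loader in this example returns a ModelRequestData, which the script's driver code feeds to LLM.generate along with the fetched images. A minimal sketch of that consumption path, using an illustrative question and image URLs rather than the ones the real script defines:

from vllm import SamplingParams

# Illustrative inputs; the actual example script defines its own question and URLs.
QUESTION = "What is the content of each image?"
IMAGE_URLS = [
    "https://example.com/image_1.jpg",
    "https://example.com/image_2.jpg",
]

req = load_pixtral_hf(QUESTION, IMAGE_URLS)

sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=128,
                                 stop_token_ids=req.stop_token_ids)

outputs = req.llm.generate(
    {
        "prompt": req.prompt,
        # One image per [IMG] placeholder, in the same order as the prompt.
        "multi_modal_data": {"image": req.image_data},
    },
    sampling_params=sampling_params)

print(outputs[0].outputs[0].text)

The important detail for multi-image prompts is that the images in multi_modal_data are consumed in order, one per image placeholder token in the prompt.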
6 changes: 6 additions & 0 deletions vllm/model_executor/models/llava.py
@@ -546,6 +546,12 @@ def _parse_and_validate_image_input(
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

if self.config.vision_config.model_type == "pixtral":
return LlavaImagePixelInputs(
type="pixel_values",
data=flatten_bn(pixel_values),
)

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
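A note on the llava.py branch above: Pixtral-style vision encoders accept images at (roughly) native resolution, so the per-image pixel-value tensors can differ in shape and cannot be checked against a single expected (3, H, W) size the way fixed-resolution CLIP/SigLIP inputs are, which is presumably why this branch returns the flattened list without calling _validate_pixel_values. A toy illustration of the shape situation (shapes invented for the example):

import torch

# Fixed-resolution backbone: every image maps to the same shape, so the batch
# can be stacked and validated against one expected size.
clip_like = [torch.randn(3, 336, 336) for _ in range(2)]
stacked = torch.stack(clip_like)  # shape (2, 3, 336, 336)

# Variable-resolution backbone (Pixtral HF): shapes differ per image, so the
# inputs are kept as a flat list instead of being stacked and size-checked.
pixtral_like = [torch.randn(3, 512, 768), torch.randn(3, 1024, 640)]
# torch.stack(pixtral_like)  # would raise: stack expects each tensor to be equal size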
2 changes: 1 addition & 1 deletion vllm/model_executor/models/pixtral.py
@@ -774,7 +774,7 @@ def get_num_image_tokens(
) -> int:
return get_pixtral_hf_image_feature_size(
image_size=self.vision_config.image_size,
-        patch_size=self.get_image_size(),
+        patch_size=self.vision_config.patch_size,
)

def get_max_image_tokens(self) -> int:
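The one-line pixtral.py fix is easy to sanity-check: the helper derives the number of image tokens from the ratio of image size to patch size, so passing the image size in the patch-size slot collapses the patch grid to a single token. A rough back-of-the-envelope illustration (config values are illustrative of a typical Pixtral-HF setup, and the real helper may also count per-row image-break tokens):

# Illustrative config values; not read from an actual model config here.
image_size = 1024
patch_size = 16

patches_per_side = image_size // patch_size   # 64
num_image_tokens = patches_per_side ** 2      # 4096, ignoring break tokens

# With the old bug, get_image_size() was passed where patch_size belongs:
buggy_side = image_size // image_size         # 1
buggy_tokens = buggy_side ** 2                # 1 -- a massive under-count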
9 changes: 9 additions & 0 deletions vllm/model_executor/models/utils.py
@@ -281,6 +281,15 @@ def flatten_bn(
...


@overload
def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
*,
concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
...


def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
*,
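The utils.py hunk adds a catch-all @overload, presumably so that calls where concat is an arbitrary bool (rather than the literal True or an omitted argument) still resolve for type checkers, while the more specific overloads keep their precise return types. A generic sketch of the pattern with a simplified, hypothetical implementation (not the vLLM source):

from typing import List, Literal, Union, overload

import torch


@overload
def flatten_bn(x: List[torch.Tensor], *, concat: Literal[True]) -> torch.Tensor:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    ...


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    # Simplified behaviour: flatten or concatenate along the batch dimension.
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)
    if concat:
        return torch.cat(x)
    return x


xs = [torch.randn(2, 4), torch.randn(3, 4)]
t = flatten_bn(xs, concat=True)       # checker infers torch.Tensor here
flag = bool(t.numel() > 0)
maybe = flatten_bn(xs, concat=flag)   # resolves via the new catch-all overload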