
Commit 7c1e0d3

Support E5-V
1 parent 29acd2c commit 7c1e0d3

9 files changed: +277, -57 lines


docs/source/models/supported_models.rst

Lines changed: 6 additions & 0 deletions
@@ -484,6 +484,12 @@ Multimodal Embedding
     - Example HF Models
     - :ref:`LoRA <lora>`
     - :ref:`PP <distributed_serving>`
+  * - :code:`LlavaNextForConditionalGeneration`
+    - LLaVA-NeXT-based
+    - T + I
+    - :code:`royokong/e5-v-2`, :code:`royokong/e5-v`
+    -
+    - ✅︎
   * - :code:`Phi3VForCausalLM`
     - Phi-3-Vision-based
     - T + I
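Note: with this row in place, the LLaVA-NeXT-based E5-V checkpoints can be loaded with task="embedding". A minimal sketch of the intended usage, pieced together from the example script added in this commit (the prompt template and model name come from that example; the embedding length depends on the checkpoint):

from vllm import LLM
from vllm.assets.image import ImageAsset

# Llama-3 chat template used by E5-V (taken from the example script below).
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'  # noqa: E501

llm = LLM(model="royokong/e5-v-2", task="embedding")

image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
prompt = llama3_template.format("<image>\nSummary above image in one word: ")

# encode() returns a list of EmbeddingRequestOutputs, one per input prompt.
outputs = llm.encode({"prompt": prompt, "multi_modal_data": {"image": image}})
print(outputs[0].outputs.embedding)  # one embedding vector as a list of floats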

examples/offline_inference_vision_language.py

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
 """
-This example shows how to use vLLM for running offline inference
-with the correct prompt format on vision language models.
+This example shows how to use vLLM for running offline inference with
+the correct prompt format on vision language models for text generation.
 
 For most models, the prompt format should follow corresponding examples
 on HuggingFace model repository.
@@ -450,7 +450,7 @@ def main(args):
 if __name__ == "__main__":
     parser = FlexibleArgumentParser(
         description='Demo on using vLLM for offline inference with '
-        'vision language models')
+        'vision language models for text generation')
     parser.add_argument('--model-type',
                         '-m',
                         type=str,
Lines changed: 126 additions & 21 deletions
@@ -1,22 +1,127 @@
+"""
+This example shows how to use vLLM for running offline inference with
+the correct prompt format on vision language models for multimodal embedding.
+
+For most models, the prompt format should follow corresponding examples
+on HuggingFace model repository.
+"""
+from argparse import Namespace
+from typing import List, NamedTuple, Optional, Union
+
+from PIL.Image import Image
+
 from vllm import LLM
-from vllm.assets.image import ImageAsset
-
-image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
-prompt = "<|image_1|> Represent the given image with the following question: What is in the image"  # noqa: E501
-
-# Create an LLM.
-llm = LLM(
-    model="TIGER-Lab/VLM2Vec-Full",
-    task="embedding",
-    trust_remote_code=True,
-    max_model_len=4096,
-    max_num_seqs=2,
-    mm_processor_kwargs={"num_crops": 16},
-)
-
-# Generate embedding. The output is a list of EmbeddingRequestOutputs.
-outputs = llm.encode({"prompt": prompt, "multi_modal_data": {"image": image}})
-
-# Print the outputs.
-for output in outputs:
-    print(output.outputs.embedding)  # list of 3072 floats
+from vllm.multimodal.utils import fetch_image
+from vllm.utils import FlexibleArgumentParser
+
+
+class ModelRequestData(NamedTuple):
+    llm: LLM
+    prompt: str
+    stop_token_ids: Optional[List[str]]
+    image: Optional[Image]
+
+
+def run_e5_v(text_or_image: Union[str, Image]):
+    llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'  # noqa: E501
+
+    if isinstance(text_or_image, str):
+        prompt = llama3_template.format(
+            f"{text_or_image}\nSummary above sentence in one word: ")
+        image = None
+    else:
+        prompt = llama3_template.format(
+            "<image>\nSummary above image in one word: ")
+        image = text_or_image
+
+    llm = LLM(
+        model="royokong/e5-v-2",
+        task="embedding",
+    )
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=None,
+        image=image,
+    )
+
+
+def run_vlm2vec(text_or_image: Union[str, Image]):
+    if isinstance(text_or_image, str):
+        prompt = f"Find me an everyday image that matches the given caption: {text_or_image}"  # noqa: E501
+        image = None
+    else:
+        prompt = "<|image_1|> Represent the given image with the following question: What is in the image"  # noqa: E501
+        image = text_or_image
+
+    llm = LLM(
+        model="TIGER-Lab/VLM2Vec-Full",
+        task="embedding",
+        trust_remote_code=True,
+        mm_processor_kwargs={"num_crops": 4},
+    )
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=None,
+        image=image,
+    )
+
+
+def get_text_or_image(modality: str):
+    if modality == "text":
+        return "A dog sitting in the grass"
+
+    if modality == "image":
+        image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg"
+        return fetch_image(image_url)
+
+    msg = f"Modality {modality} is not supported."
+    raise ValueError(msg)
+
+
+def run_encode(model: str, modality: str):
+    text_or_image = get_text_or_image(modality)
+    req_data = model_example_map[model](text_or_image)
+
+    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
+    outputs = req_data.llm.encode(
+        {
+            "prompt": req_data.prompt,
+            "multi_modal_data": {
+                "image": req_data.image
+            },
+        }, )
+
+    for output in outputs:
+        print(output.outputs.embedding)
+
+
+def main(args: Namespace):
+    run_encode(args.model, args.modality)
+
+
+model_example_map = {
+    "e5_v": run_e5_v,
+    "vlm2vec": run_vlm2vec,
+}
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(
+        description='Demo on using vLLM for offline inference with '
+        'vision language models for multimodal embedding')
+    parser.add_argument('--model-type',
+                        '-m',
+                        type=str,
+                        default="vlm2vec",
+                        choices=model_example_map.keys(),
+                        help='The name of the embedding model.')
+    parser.add_argument('--modality',
+                        type=str,
+                        default="image",
+                        choices=['text', 'image'],
+                        help='Modality of the input.')
+    args = parser.parse_args()
+    main(args)
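Note: not part of the diff, but to illustrate what the new example enables: encoding one text prompt and one image prompt with the same E5-V model places both modalities in a shared embedding space, so the vectors can be compared directly, for instance with cosine similarity. The model name and prompt template below are reused from the example above; the comparison itself is only a sketch.

import torch
import torch.nn.functional as F

from vllm import LLM
from vllm.multimodal.utils import fetch_image

llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'  # noqa: E501

llm = LLM(model="royokong/e5-v-2", task="embedding")

text_prompt = llama3_template.format(
    "A dog sitting in the grass\nSummary above sentence in one word: ")
image = fetch_image(
    "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg")
image_prompt = llama3_template.format(
    "<image>\nSummary above image in one word: ")

# Both prompts are encoded by the same model; output order matches input order.
text_out, image_out = llm.encode([
    {"prompt": text_prompt},
    {"prompt": image_prompt, "multi_modal_data": {"image": image}},
])

text_emb = torch.tensor(text_out.outputs.embedding)
image_emb = torch.tensor(image_out.outputs.embedding)
# Higher cosine similarity means the caption matches the image better.
print(F.cosine_similarity(text_emb, image_emb, dim=0).item())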

examples/offline_inference_vision_language_multi_image.py

Lines changed: 4 additions & 3 deletions
@@ -1,7 +1,7 @@
 """
 This example shows how to use vLLM for running offline inference with
-multi-image input on vision language models, using the chat template defined
-by the model.
+multi-image input on vision language models for text generation,
+using the chat template defined by the model.
 """
 from argparse import Namespace
 from typing import List, NamedTuple, Optional
@@ -334,7 +334,8 @@ def main(args: Namespace):
 if __name__ == "__main__":
     parser = FlexibleArgumentParser(
         description='Demo on using vLLM for offline inference with '
-        'vision language models that support multi-image input')
+        'vision language models that support multi-image input for text '
+        'generation')
     parser.add_argument('--model-type',
                         '-m',
                         type=str,

tests/conftest.py

Lines changed: 30 additions & 21 deletions
@@ -42,10 +42,12 @@
 _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
 _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
 
-PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
-PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
-                         List[List[Tuple[np.ndarray, int]]]]
-PromptVideoInput = Union[List[np.ndarray], List[List[np.ndarray]]]
+_M = TypeVar("_M")
+_PromptMultiModalInput = Union[List[_M], List[List[_M]]]
+
+PromptImageInput = _PromptMultiModalInput[Image.Image]
+PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
+PromptVideoInput = _PromptMultiModalInput[np.ndarray]
 
 
 def _read_prompts(filename: str) -> List[str]:
@@ -316,12 +318,12 @@ def get_inputs(
                 "text": prompt,
                 "return_tensors": "pt",
             }
-            if images is not None and images[i] is not None:
-                processor_kwargs["images"] = images[i]
-            if videos is not None and videos[i] is not None:
-                processor_kwargs["videos"] = videos[i]
-            if audios is not None and audios[i] is not None:
-                audio, sr = audios[i]
+            if images is not None and (image := images[i]) is not None:
+                processor_kwargs["images"] = image
+            if videos is not None and (video := videos[i]) is not None:
+                processor_kwargs["videos"] = video
+            if audios is not None and (audio_tuple := audios[i]) is not None:
+                audio, sr = audio_tuple
                 processor_kwargs["audio"] = audio
                 processor_kwargs["sampling_rate"] = sr
 
@@ -336,7 +338,7 @@ def generate(
         self,
         prompts: List[str],
         images: Optional[PromptImageInput] = None,
-        videos: Optional[List[np.ndarray]] = None,
+        videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[List[int]], List[str]]]:
@@ -366,7 +368,7 @@ def generate_greedy(
         prompts: List[str],
         max_tokens: int,
         images: Optional[PromptImageInput] = None,
-        videos: Optional[List[np.ndarray]] = None,
+        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[int], str]]:
@@ -407,7 +409,7 @@ def generate_greedy_logprobs(
         prompts: List[str],
         max_tokens: int,
         images: Optional[PromptImageInput] = None,
-        videos: Optional[List[np.ndarray]] = None,
+        videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
         **kwargs: Any,
     ) -> List[List[torch.Tensor]]:
@@ -486,7 +488,7 @@ def generate_greedy_logprobs_limit(
         num_logprobs: int,
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
-        videos: Optional[List[np.ndarray]] = None,
+        videos: Optional[PromptVideoInput] = None,
         **kwargs: Any,
     ) -> List[TokensTextLogprobs]:
         all_inputs = self.get_inputs(prompts,
@@ -835,13 +837,20 @@ def generate_beam_search(
             returned_outputs.append((token_ids, texts))
         return returned_outputs
 
-    def encode(self, prompts: List[str]) -> List[List[float]]:
-        req_outputs = self.model.encode(prompts)
-        outputs = []
-        for req_output in req_outputs:
-            embedding = req_output.outputs.embedding
-            outputs.append(embedding)
-        return outputs
+    def encode(
+        self,
+        prompts: List[str],
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
+    ) -> List[List[float]]:
+        inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)
+
+        req_outputs = self.model.encode(inputs)
+        return [req_output.outputs.embedding for req_output in req_outputs]
 
     def __enter__(self):
         return self
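Note: the new _PromptMultiModalInput alias is generic over the item type, so each per-prompt entry can be either a single item or a list of items, and one annotation covers both single- and multi-input prompts. A small illustration of the alias (not from the commit):

from typing import List, TypeVar, Union

from PIL import Image

_M = TypeVar("_M")
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]

PromptImageInput = _PromptMultiModalInput[Image.Image]

# One image per prompt...
single_per_prompt: PromptImageInput = [Image.new("RGB", (8, 8))]
# ...or several images per prompt; both forms satisfy the same annotation.
multi_per_prompt: PromptImageInput = [[Image.new("RGB", (8, 8)),
                                       Image.new("RGB", (8, 8))]]
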
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+import pytest
+import torch.nn.functional as F
+
+from ....conftest import IMAGE_ASSETS
+from ..utils import check_embeddings_close
+
+llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'  # noqa: E501
+
+HF_TEXT_PROMPTS = [
+    llama3_template.format(
+        "The label of the object is stop sign\nSummary above sentence in one word: "  # noqa: E501
+    ),
+    llama3_template.format(
+        "cherry blossom\nSummary above sentence in one word: "),
+]
+
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+    "stop_sign":
+    llama3_template.format("<image>\nSummary above image in one word: "),
+    "cherry_blossom":
+    llama3_template.format("<image>\nSummary above image in one word: "),
+})
+
+MODELS = ["royokong/e5-v-2"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_models(
+    hf_runner,
+    vllm_runner,
+    image_assets,
+    model: str,
+    dtype: str,
+) -> None:
+    input_texts_images = [
+        *((text, None) for text in HF_TEXT_PROMPTS),
+        *((text, image)
+          for text, image in zip(HF_IMAGE_PROMPTS, image_assets)),
+    ]
+    input_texts = [text for text, _ in input_texts_images]
+    input_images = [image for _, image in input_texts_images]
+
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+    with vllm_runner(model, task="embedding", dtype=dtype,
+                     enforce_eager=True) as vllm_model:
+        vllm_outputs = vllm_model.encode(input_texts, images=input_images)
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        all_inputs = hf_model.get_inputs(input_texts, images=input_images)
+
+        all_outputs = []
+        for inputs in all_inputs:
+            # Based on: https://huggingface.co/royokong/e5-v
+            outputs = hf_model.model(
+                **hf_model.wrap_device(inputs,
+                                       device=hf_model.model.device.type),
+                return_dict=True,
+                output_hidden_states=True,
+            )
+            pooled_output = F.normalize(outputs.hidden_states[-1][:, -1, :],
+                                        dim=-1)
+
+            all_outputs.append(pooled_output.tolist())
+
+        hf_outputs = all_outputs
+
+    check_embeddings_close(
+        embeddings_0_lst=hf_outputs,
+        embeddings_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
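Note: check_embeddings_close is the shared helper from the embedding test utilities. Conceptually it verifies that corresponding vectors from the two runners point in nearly the same direction; a rough standalone sketch of that kind of check follows (not the actual helper, and the tolerance is an assumption):

from typing import List

import torch
import torch.nn.functional as F


def assert_embeddings_close(embeddings_0: List[List[float]],
                            embeddings_1: List[List[float]],
                            min_cosine: float = 0.99) -> None:
    # Compare each pair of embeddings by cosine similarity; values close to
    # 1.0 mean the HF reference and the vLLM output agree up to numerical
    # noise (e.g. from running in half precision).
    for vec_0, vec_1 in zip(embeddings_0, embeddings_1):
        sim = F.cosine_similarity(torch.tensor(vec_0),
                                  torch.tensor(vec_1),
                                  dim=0)
        assert sim.item() >= min_cosine, f"cosine similarity too low: {sim}"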
