Commit fab5f53

[Core][VLM] Stack multimodal tensors to represent multiple images within each prompt (#7902)
1 parent 9c71c97 commit fab5f53

15 files changed: 214 additions & 60 deletions

docs/source/dev/multimodal/multimodal_index.rst

Lines changed: 0 additions & 2 deletions
@@ -45,8 +45,6 @@ Base Classes
 
 .. autodata:: vllm.multimodal.NestedTensors
 
-.. autodata:: vllm.multimodal.BatchedTensors
-
 .. autodata:: vllm.multimodal.BatchedTensorInputs
 
 .. autoclass:: vllm.multimodal.MultiModalDataBuiltins

tests/multimodal/test_base.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+import torch
+
+from vllm.multimodal.base import MultiModalInputs, NestedTensors
+
+
+def assert_nested_tensors_equal(expected: NestedTensors,
+                                actual: NestedTensors):
+    assert type(expected) == type(actual)
+    if isinstance(expected, torch.Tensor):
+        assert torch.equal(expected, actual)
+    else:
+        for expected_item, actual_item in zip(expected, actual):
+            assert_nested_tensors_equal(expected_item, actual_item)
+
+
+def assert_multimodal_inputs_equal(expected: MultiModalInputs,
+                                   actual: MultiModalInputs):
+    assert set(expected.keys()) == set(actual.keys())
+    for key in expected:
+        assert_nested_tensors_equal(expected[key], actual[key])
+
+
+def test_multimodal_input_batch_single_tensor():
+    t = torch.rand([1, 2])
+    result = MultiModalInputs.batch([{"image": t}])
+    assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)})
+
+
+def test_multimodal_input_batch_multiple_tensors():
+    a = torch.rand([1, 1, 2])
+    b = torch.rand([1, 1, 2])
+    c = torch.rand([1, 1, 2])
+    result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
+    assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])})
+
+
+def test_multimodal_input_batch_multiple_heterogeneous_tensors():
+    a = torch.rand([1, 2, 2])
+    b = torch.rand([1, 3, 2])
+    c = torch.rand([1, 4, 2])
+    result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
+    assert_multimodal_inputs_equal(result, {"image": [a, b, c]})
+
+
+def test_multimodal_input_batch_nested_tensors():
+    a = torch.rand([2, 3])
+    b = torch.rand([2, 3])
+    c = torch.rand([2, 3])
+    result = MultiModalInputs.batch([{
+        "image": [a]
+    }, {
+        "image": [b]
+    }, {
+        "image": [c]
+    }])
+    assert_multimodal_inputs_equal(result, {
+        "image":
+        torch.stack([a.unsqueeze(0),
+                     b.unsqueeze(0),
+                     c.unsqueeze(0)])
+    })
+
+
+def test_multimodal_input_batch_heterogeneous_lists():
+    a = torch.rand([1, 2, 3])
+    b = torch.rand([1, 2, 3])
+    c = torch.rand([1, 2, 3])
+    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
+    assert_multimodal_inputs_equal(
+        result,
+        {"image": [torch.stack([a, b]), c.unsqueeze(0)]})
+
+
+def test_multimodal_input_batch_multiple_batchable_lists():
+    a = torch.rand([1, 2, 3])
+    b = torch.rand([1, 2, 3])
+    c = torch.rand([1, 2, 3])
+    d = torch.rand([1, 2, 3])
+    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}])
+    assert_multimodal_inputs_equal(
+        result,
+        {"image": torch.stack([torch.stack([a, b]),
+                               torch.stack([c, d])])})

vllm/model_executor/models/blip2.py

Lines changed: 7 additions & 0 deletions
@@ -555,6 +555,9 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Remove the N dimension until multiple images are supported.
+            pixel_values = pixel_values.squeeze(1)
+
             return Blip2ImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
@@ -564,6 +567,10 @@ def _parse_and_validate_image_input(
             if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
+
+            # Remove the N dimension until multiple images are supported.
+            image_embeds = image_embeds.squeeze(1)
+
             return Blip2ImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,
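
The squeeze above is the pattern most single-image models in this commit follow. As a rough illustration (the concrete sizes below are assumptions made for the example, not shapes taken from BLIP-2): batched pixel values now arrive with an extra per-prompt image dimension of size 1, and squeeze(1) drops it so downstream code keeps seeing the old layout.

import torch

# Assumed layout: (batch, images_per_prompt, C, H, W) with one image per prompt.
pixel_values = torch.rand(4, 1, 3, 224, 224)
pixel_values = pixel_values.squeeze(1)
print(pixel_values.shape)  # torch.Size([4, 3, 224, 224])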

vllm/model_executor/models/chameleon.py

Lines changed: 3 additions & 0 deletions
@@ -946,6 +946,9 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Remove the N dimension until multiple images are supported.
+            pixel_values = pixel_values.squeeze(1)
+
             return ChameleonImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),

vllm/model_executor/models/fuyu.py

Lines changed: 3 additions & 0 deletions
@@ -249,6 +249,9 @@ def _parse_and_validate_image_input(
         image_patches = kwargs.pop("image_patches", None)
 
         if isinstance(image_patches, torch.Tensor):
+            # Remove the N dimension until multiple images are supported.
+            image_patches = image_patches.squeeze(1)
+
             expected_feature_size = self.image_feature_size
             if image_patches.size(-1) != expected_feature_size:
                 raise ValueError(

vllm/model_executor/models/internvl.py

Lines changed: 9 additions & 0 deletions
@@ -244,6 +244,8 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
                                          min_num,
                                          max_num,
                                          use_thumbnail=use_thumbnail)
+        # Add an N dimension for number of images per prompt (currently 1).
+        data = data.unsqueeze(0)
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
@@ -410,6 +412,10 @@ def _parse_and_validate_image_input(
             if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
+
+            # Flatten the B and N dimensions
+            image_embeds = image_embeds.flatten(0, 2)
+
             return InternVLImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,
@@ -422,6 +428,9 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Flatten the B and N dimensions
+            pixel_values = pixel_values.flatten(0, 2)
+
             return InternVLImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
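
InternVL tiles each image into several patches, so instead of squeezing it collapses the leading dimensions with flatten. A sketch under assumed shapes (the tile count and image size here are invented for illustration): flatten(0, 2) merges the first three axes into one.

import torch

# Assumed layout: (batch, images_per_prompt, num_tiles, C, H, W).
pixel_values = torch.rand(2, 1, 7, 3, 448, 448)
pixel_values = pixel_values.flatten(0, 2)
print(pixel_values.shape)  # torch.Size([14, 3, 448, 448])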

vllm/model_executor/models/llava.py

Lines changed: 8 additions & 0 deletions
@@ -232,6 +232,10 @@ def _parse_and_validate_image_input(
             if not isinstance(pixel_values, torch.Tensor):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
+
+            # Remove the N dimension until multiple images are supported.
+            pixel_values = pixel_values.squeeze(1)
+
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
@@ -241,6 +245,10 @@ def _parse_and_validate_image_input(
             if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
+
+            # Remove the N dimension until multiple images are supported.
+            image_embeds = image_embeds.squeeze(1)
+
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,

vllm/model_executor/models/llava_next.py

Lines changed: 11 additions & 0 deletions
@@ -361,6 +361,14 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")
 
+            # Remove the N dimension until multiple images are supported.
+            if isinstance(pixel_values, torch.Tensor):
+                pixel_values = pixel_values.squeeze(1)
+            else:
+                pixel_values = [t.squeeze(0) for t in pixel_values]
+
+            image_sizes = image_sizes.squeeze(1)
+
             return LlavaNextImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
@@ -372,6 +380,9 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of image embeds. "
                                  f"Got type: {type(image_embeds)}")
 
+            # Remove the N dimension until multiple images are supported.
+            image_embeds = image_embeds.squeeze(1)
+
             return LlavaNextImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,
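
Because LLaVA-NeXT prompts can produce different numbers of patches, the batched pixel values may be either one stacked tensor or a list of per-prompt tensors (see the batching sketch after the tests above). The two branches therefore squeeze different axes; a hedged illustration with made-up shapes:

import torch

# Homogeneous case: stacked to (batch, images_per_prompt=1, patches, C, H, W).
stacked = torch.rand(2, 1, 5, 3, 336, 336)
print(stacked.squeeze(1).shape)  # torch.Size([2, 5, 3, 336, 336])

# Heterogeneous case: one (images_per_prompt=1, patches, C, H, W) tensor per prompt.
per_prompt = [torch.rand(1, 5, 3, 336, 336), torch.rand(1, 7, 3, 336, 336)]
print([t.squeeze(0).shape for t in per_prompt])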

vllm/model_executor/models/minicpmv.py

Lines changed: 8 additions & 3 deletions
@@ -594,9 +594,14 @@ def _parse_and_validate_inputs(
 
         pixel_values_flat: List[torch.Tensor] = []
         tgt_sizes_flat: List[torch.Tensor] = []
-        for b in range(len(pixel_values)):
-            pixel_values_flat += pixel_values[b]
-            tgt_sizes_flat += tgt_sizes[b]
+        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
+            if len(pixel_b) != len(tgt_b):
+                raise ValueError("Inconsistent N lengths, found: "
+                                 f"{len(pixel_b)} vs {len(tgt_b)}")
+
+            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
+                pixel_values_flat += pixel_n
+                tgt_sizes_flat += tgt_n
 
         # NOTE: Input IDs does not contain image tokens during memory profiling,
         # so we allow it to be empty
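
MiniCPM-V keeps its pixel values as nested lists rather than stacked tensors, so the rewrite walks one extra level of nesting (per prompt, then per image) and checks that the per-prompt image counts match. A tiny self-contained sketch of the flattening order, using placeholder strings where the model would have tensors:

# Nesting: batch -> images per prompt -> slices per image (placeholders, not real tensors).
pixel_values = [[["p0_img0_s0", "p0_img0_s1"]], [["p1_img0_s0"]]]
tgt_sizes = [[["t0_img0_s0", "t0_img0_s1"]], [["t1_img0_s0"]]]

pixel_values_flat = []
tgt_sizes_flat = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
    if len(pixel_b) != len(tgt_b):
        raise ValueError("Inconsistent N lengths, found: "
                         f"{len(pixel_b)} vs {len(tgt_b)}")
    for pixel_n, tgt_n in zip(pixel_b, tgt_b):
        pixel_values_flat += pixel_n
        tgt_sizes_flat += tgt_n

print(pixel_values_flat)  # ['p0_img0_s0', 'p0_img0_s1', 'p1_img0_s0']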

vllm/model_executor/models/paligemma.py

Lines changed: 8 additions & 0 deletions
@@ -185,6 +185,10 @@ def _parse_and_validate_image_input(
             if not isinstance(pixel_values, torch.Tensor):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
+
+            # Remove the N dimension until multiple images are supported.
+            pixel_values = pixel_values.squeeze(1)
+
             return PaliGemmaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
@@ -194,6 +198,10 @@ def _parse_and_validate_image_input(
             if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
+
+            # Remove the N dimension until multiple images are supported.
+            image_embeds = image_embeds.squeeze(1)
+
             return PaliGemmaImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,

0 commit comments
