
Commit 7da0eef

zucchini-nlp authored and ArthurZucker committed
VLMs: fix number of image tokens (#34332)
* fix
* fix tests
* add tests
* style
* style
* fix qwen after rebase
* fix video llava
1 parent bc598c0 commit 7da0eef

File tree

15 files changed (+237, -15 lines)

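The diffs below all enforce the same invariant: the number of image (or video) placeholder tokens in `input_ids` must equal the number of feature vectors produced by the vision tower, counted over the whole batch rather than over a single row or a single axis. A minimal sketch of the fixed check, using made-up shapes and a made-up placeholder id purely for illustration:

```python
import torch

image_token_id = 32000  # hypothetical placeholder id, not a real model constant
# two prompts, each containing four image placeholder tokens
input_ids = torch.tensor([
    [1, 32000, 32000, 32000, 32000, 2],
    [1, 32000, 32000, 32000, 32000, 2],
])
# hypothetical features: (num_images, tokens_per_image, hidden_dim)
image_features = torch.randn(2, 4, 8)

# old-style count: only the first row and only one feature axis -> 4 == 4, so a mismatch can go unnoticed
n_tokens_old = (input_ids == image_token_id).sum(dim=-1)[0].item()
n_features_old = image_features.shape[1]

# fixed count: batch-wide totals on both sides -> 8 vs 8
n_tokens = (input_ids == image_token_id).sum().item()
n_features = image_features.shape[0] * image_features.shape[1]
if n_tokens != n_features:
    raise ValueError(f"Image features and image tokens do not match: tokens: {n_tokens}, features {n_features}")
```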

src/transformers/models/chameleon/modeling_chameleon.py

Lines changed: 1 addition & 1 deletion
@@ -1288,7 +1288,7 @@ def forward(
         if pixel_values is not None:
             image_tokens = self.get_image_tokens(pixel_values)
             n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
-            n_image_features = image_tokens.shape[0]
+            n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
             if n_image_tokens_in_text != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"

src/transformers/models/llava/modeling_llava.py

Lines changed: 3 additions & 2 deletions
@@ -522,8 +522,9 @@ def forward(
 
         # TODO: @raushan retain only the new behavior after v4.47
         elif image_features is not None:
-            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-            n_image_features = image_features.shape[1]
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

src/transformers/models/llava_next_video/modeling_llava_next_video.py

Lines changed: 1 addition & 0 deletions
@@ -1020,6 +1020,7 @@ def forward(
         if image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

src/transformers/models/llava_next_video/modular_llava_next_video.py

Lines changed: 1 addition & 0 deletions
@@ -533,6 +533,7 @@ def forward(
         if image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

src/transformers/models/llava_onevision/modeling_llava_onevision.py

Lines changed: 2 additions & 0 deletions
@@ -679,6 +679,7 @@ def forward(
             )
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -704,6 +705,7 @@ def forward(
             )
             video_features = torch.cat((video_features, image_newline), dim=1)
             video_features = video_features.flatten(0, 1)
+
             n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
             n_video_features = video_features.shape[0]
             if n_video_tokens != n_video_features:

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 3 additions & 2 deletions
@@ -1503,13 +1503,14 @@ def get_rope_index(
         mrope_position_deltas = []
         if image_grid_thw is not None or video_grid_thw is not None:
             total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = torch.ones_like(total_input_ids)
             position_ids = torch.ones(
                 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
             )
             image_index, video_index = 0, 0
             for i, input_ids in enumerate(total_input_ids):
-                if attention_mask is not None:
-                    input_ids = input_ids[attention_mask[i] == 1]
+                input_ids = input_ids[attention_mask[i] == 1]
                 image_nums, video_nums = 0, 0
                 vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                 vision_tokens = input_ids[vision_start_indices + 1]
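The Qwen2-VL change is different in kind: `get_rope_index` previously skipped per-row mask filtering when no `attention_mask` was passed, so the fix substitutes an explicit all-ones mask up front. A small sketch of that idiom, with toy tensors standing in for the real inputs:

```python
import torch

input_ids = torch.tensor([[1, 5, 7, 9], [1, 5, 0, 0]])  # second row is right-padded
attention_mask = None                                     # callers may omit the mask

# fall back to an all-ones mask so every row can be filtered unconditionally
if attention_mask is None:
    attention_mask = torch.ones_like(input_ids)

for i, ids in enumerate(input_ids):
    ids = ids[attention_mask[i] == 1]  # a no-op under the all-ones fallback
    print(ids.tolist())
```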

src/transformers/models/video_llava/modeling_video_llava.py

Lines changed: 4 additions & 4 deletions
@@ -622,8 +622,8 @@ def forward(
         # TODO: @raushan retain only the new behavior after v4.47
         else:
             if pixel_values_images is not None:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-                n_image_features = image_features.shape[1]
+                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                n_image_features = image_features.shape[0] * image_features.shape[1]
                 if n_image_tokens != n_image_features:
                     raise ValueError(
                         f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -638,8 +638,8 @@ def forward(
                 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
 
             if pixel_values_videos is not None:
-                n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
-                n_video_features = video_features.shape[1]
+                n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+                n_video_features = video_features.shape[0] * video_features.shape[1]
                 if n_video_tokens != n_video_features:
                     raise ValueError(
                         f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"

src/transformers/models/vipllava/modeling_vipllava.py

Lines changed: 2 additions & 2 deletions
@@ -512,8 +512,8 @@ def forward(
 
         # TODO: @raushan retain only the new behavior after v4.47
         elif image_features is not None:
-            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-            n_image_features = image_features.shape[1]
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

tests/models/llava/test_modeling_llava.py

Lines changed: 29 additions & 0 deletions
@@ -235,6 +235,35 @@ def test_inputs_embeds_matches_input_ids(self):
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))
 
+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs through an error with explicit message saying what is wrong
+        when number of images don't match number of image tokens in the text.
+        Also we need to test multi-image cases when one prompr has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successfull forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
     @unittest.skip(
         reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )

tests/models/llava_next/test_modeling_llava_next.py

Lines changed: 32 additions & 0 deletions
@@ -283,6 +283,38 @@ def test_inputs_embeds_matches_input_ids(self):
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))
 
+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs through an error with explicit message saying what is wrong
+        when number of images don't match number of image tokens in the text.
+        Also we need to test multi-image cases when one prompr has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successfull forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            image_sizes = input_dict["image_sizes"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
     @unittest.skip(
         reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )
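To exercise just the new checks locally, one option (assuming pytest and the repository's test dependencies are installed) is to filter on the test name added in the two files above:

```python
import pytest

# runs only the newly added mismatch tests from the two test files in this commit
pytest.main([
    "tests/models/llava/test_modeling_llava.py",
    "tests/models/llava_next/test_modeling_llava_next.py",
    "-k", "mismatching_num_image_tokens",
])
```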
