@@ -283,6 +283,38 @@ def test_inputs_embeds_matches_input_ids(self):
        out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
        self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs throw an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        Also tests multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            image_sizes = input_dict["image_sizes"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
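
For context on why the test expects a `ValueError`: the error would come from a count check inside the model's `forward`, comparing image placeholder tokens in the text against the images actually provided. Below is a minimal sketch of that kind of check; the names (`check_image_tokens`, `image_token_id`) and the one-entry-per-image layout of `pixel_values` are illustrative assumptions, not the exact modeling code of any architecture.

```python
import torch

# Illustrative sketch only: real modeling code differs per architecture.
def check_image_tokens(input_ids: torch.Tensor, pixel_values: torch.Tensor, image_token_id: int) -> None:
    # count image placeholder tokens across the whole batch
    n_image_tokens = int((input_ids == image_token_id).sum())
    # assume one pixel_values entry per image in this simplified setup
    n_images = pixel_values.shape[0]
    if n_image_tokens != n_images:
        raise ValueError(
            f"Number of images ({n_images}) does not match number of image tokens "
            f"({n_image_tokens}) in the text."
        )
```

Under that assumption, dropping an image while keeping its token, or duplicating `input_ids` without duplicating `pixel_values`, trips the check, while concatenating both in lockstep passes, which matches the three cases the test walks through.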