Skip to content

Commit e68f450

Browse files
committed
update
1 parent 563c71b commit e68f450

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/generation/test_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,9 @@ def test_dola_decoding_sample(self):
12631263

12641264
if model.get_output_embeddings() is None:
12651265
self.skipTest("DoLa is not supported for models that don't have output embeddings")
1266+
1267+
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
1268+
12661269
# Sets dola generation arguments such that:
12671270
# a) no EOS is generated, to ensure generation doesn't break early
12681271
# b) there are at least two forward passes in the main model, to ensure the input preparation of
@@ -1279,9 +1282,8 @@ def test_dola_decoding_sample(self):
12791282
"return_dict_in_generate": True,
12801283
"use_cache": getattr(config, "use_cache", False), # Some models don't support the cache
12811284
"dola_layers": "low",
1282-
"bad_words_ids": [[model.config.image_token_index]] if hasattr(model.config, "image_token_index") else None,
12831285
}
1284-
output_dola = model.generate(**generation_kwargs, **inputs_dict)
1286+
output_dola = model.generate(**generation_kwargs, **logits_processor_kwargs, **inputs_dict)
12851287
self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))
12861288

12871289
@pytest.mark.generate

0 commit comments

Comments
 (0)