diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 2bc1c351b8f9..d05b2a69544b 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1096,25 +1096,18 @@ def test_contrastive_generate_low_memory(self):
 
             # test output equality of low versus high memory
             model = model_class(config).to(torch_device).eval()
 
+            generate_kwargs = {
+                "top_k": 4,
+                "penalty_alpha": 0.6,
+                "max_new_tokens": self.max_new_tokens,
+                "use_cache": True,
+                "return_dict_in_generate": True,
+                "output_scores": True,
+            }
-            low_output = model.generate(
-                top_k=4,
-                penalty_alpha=0.6,
-                low_memory=True,
-                max_new_tokens=self.max_new_tokens,
-                **inputs_dict,
-                use_cache=True,
-            )
-
-            high_output = model.generate(
-                top_k=4,
-                penalty_alpha=0.6,
-                low_memory=False,
-                max_new_tokens=self.max_new_tokens,
-                **inputs_dict,
-                use_cache=True,
-            )
-            self.assertListEqual(low_output.tolist(), high_output.tolist())
+            low_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=True)
+            high_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=False)
+            self._check_similar_generate_outputs(low_output, high_output)
 
     @parameterized.expand([("random",), ("same",)])
     @pytest.mark.generate
@@ -1862,22 +1855,29 @@ def test_generate_continue_from_past_key_values(self):
             model = model_class(config).to(torch_device)
             model.eval()
-            model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
-            model.generation_config.forced_eos_token_id = None
-            model.generation_config.encoder_no_repeat_ngram_size = 0
-            model.generation_config.use_cache = True
 
             # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
             outputs = model(**inputs)
             if "past_key_values" not in outputs:
                 self.skipTest(reason="This model doesn't return `past_key_values`")
 
+            generate_kwargs = {
+                "pad_token_id": -1,
+                "eos_token_id": -1,
+                "forced_eos_token_id": None,
+                "encoder_no_repeat_ngram_size": 0,
+                "use_cache": True,
+                "do_sample": False,
+                "return_dict_in_generate": True,
+                "output_scores": True,
+            }
+
             # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
-            outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
+            outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4)
 
             # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
             # inputs may need to be tweaked across `generate` calls (like the attention mask).
-            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
+            outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
 
             # Continue from the tokens generated above, preparing the inputs accordingly
             inputs["past_key_values"] = outputs_cached.past_key_values
 
@@ -1900,10 +1900,13 @@ def test_generate_continue_from_past_key_values(self):
                     mode="constant",
                     value=1,
                 )
-            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
+            first_caches_scores = outputs_cached.scores
+            outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
+            full_cached_scores = first_caches_scores + outputs_cached.scores
+            outputs_cached.scores = full_cached_scores
 
             # The two sets of generated text and past kv should be equal to each other
-            self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
+            self._check_similar_generate_outputs(outputs, outputs_cached)
             for layer_idx in range(len(outputs_cached.past_key_values)):
                 for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
                     self.assertTrue(
@@ -1929,6 +1932,8 @@ def test_generate_continue_from_inputs_embeds(self):
             if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder")
 
+            # TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
+            # but it breaks a few models. Fix and then apply `_check_similar_generate_outputs` pattern
             if not hasattr(config, "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
 
@@ -1989,32 +1994,6 @@ def test_generate_continue_from_inputs_embeds(self):
                 )
             )
 
-    @parameterized.expand([("offloaded",)])  # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
-    @require_torch_accelerator
-    @pytest.mark.generate
-    def test_offloaded_cache_implementation(self, cache_implementation):
-        """Tests we can generate by indicating `cache_implementation` for each possible cache class"""
-        for model_class in self.all_generative_model_classes:
-            if not model_class._supports_cache_class:
-                self.skipTest(reason="This model does not support the new cache format")
-
-            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-
-            model = model_class(config).to(torch_device).eval()
-            generation_kwargs = {
-                "max_new_tokens": 5,
-                "use_cache": True,
-                "cache_implementation": cache_implementation,
-            }
-
-            legacy_results = model.generate(**generation_kwargs, **inputs_dict)
-
-            # Most cache classes have their own tests except for some that are tested here
-            # The ones here do not need special treatment when passing `cache_implementation`
-            # and are not bound to specific models only
-            new_results = model.generate(**generation_kwargs, **inputs_dict)
-            self.assertListEqual(legacy_results.tolist(), new_results.tolist())
-
     @pytest.mark.generate
     def test_generate_with_static_cache(self):
         """
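Note on the new assertion pattern: the diff replaces exact token-level equality (`assertListEqual`) with `_check_similar_generate_outputs`, because low-memory contrastive search and chunked generation that resumes from `past_key_values` can legitimately pick a different token when two logits are nearly tied. This is also why the chunked run stitches `first_caches_scores` and the second call's scores back together: both runs must expose one score tensor per generated token for the comparison. The helper below is a minimal hypothetical sketch of that idea, not the actual `GenerationTesterMixin` implementation; it assumes both outputs carry equally-shaped `.sequences` and a per-generated-token `.scores` tuple (hence the added `return_dict_in_generate=True` and `output_scores=True` kwargs):

import torch

def check_similar_generate_outputs(output_1, output_2, atol=1e-5, rtol=1e-5):
    # Hypothetical simplification: two runs are "similar" when the sequences
    # match exactly, or when each first mismatching token is a tie-break
    # between numerically close scores.
    matches = output_1.sequences == output_2.sequences
    if matches.all():
        return
    # `scores` holds one tensor per *generated* token, so offset sequence
    # indices by the prompt length before indexing into it.
    prompt_length = output_1.sequences.shape[1] - len(output_1.scores)
    for batch_idx in range(matches.shape[0]):
        row = matches[batch_idx]
        if row.all():
            continue
        first_mismatch = int(row.int().argmin()) - prompt_length  # first False
        assert torch.allclose(
            output_1.scores[first_mismatch][batch_idx],
            output_2.scores[first_mismatch][batch_idx],
            atol=atol,
            rtol=rtol,
        ), f"runs diverge at generated token {first_mismatch}, batch {batch_idx}"

Under this reading, the removed `test_offloaded_cache_implementation` was redundant anyway: it compared two `generate` calls with identical kwargs against each other, so it could never fail for the reason its name suggests.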