# [tests] fix flaky pattern in `test_generate_continue_from_past_key_values` (#37724)
```diff
@@ -1096,25 +1096,18 @@ def test_contrastive_generate_low_memory(self):
 
             # test output equality of low versus high memory
             model = model_class(config).to(torch_device).eval()
+            generate_kwargs = {
+                "top_k": 4,
+                "penalty_alpha": 0.6,
+                "max_new_tokens": self.max_new_tokens,
+                "use_cache": True,
+                "return_dict_in_generate": True,
+                "output_scores": True,
+            }
 
-            low_output = model.generate(
-                top_k=4,
-                penalty_alpha=0.6,
-                low_memory=True,
-                max_new_tokens=self.max_new_tokens,
-                **inputs_dict,
-                use_cache=True,
-            )
-
-            high_output = model.generate(
-                top_k=4,
-                penalty_alpha=0.6,
-                low_memory=False,
-                max_new_tokens=self.max_new_tokens,
-                **inputs_dict,
-                use_cache=True,
-            )
-            self.assertListEqual(low_output.tolist(), high_output.tolist())
+            low_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=True)
+            high_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=False)
+            self._check_similar_generate_outputs(low_output, high_output)
 
     @parameterized.expand([("random",), ("same",)])
     @pytest.mark.generate
```

> **Review comment** on `"output_scores": True`: (the correct equivalence check needs the scores)
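The review note above is the heart of the fix: asserting exact token equality is flaky whenever two candidate tokens have near-tied scores, so the suite now compares outputs at the score level. A minimal sketch of that kind of check, as a hypothetical helper (not the actual `GenerationTesterMixin` implementation), assuming both outputs were produced with `return_dict_in_generate=True` and `output_scores=True`:

```python
import torch

def check_similar_generate_outputs(out_1, out_2, atol=1e-3, rtol=1e-3):
    # If the generated token ids already match, the outputs are equivalent.
    if out_1.sequences.tolist() == out_2.sequences.tolist():
        return
    # Otherwise, tolerate a token mismatch as long as the per-step scores
    # (tuples of (batch_size, vocab_size) tensors) are numerically close:
    # a flipped argmax between two near-tied tokens is noise, not a bug.
    assert len(out_1.scores) == len(out_2.scores)
    for scores_1, scores_2 in zip(out_1.scores, out_2.scores):
        torch.testing.assert_close(scores_1, scores_2, atol=atol, rtol=rtol)
```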
```diff
@@ -1862,22 +1855,29 @@ def test_generate_continue_from_past_key_values(self):
 
             model = model_class(config).to(torch_device)
             model.eval()
-            model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
-            model.generation_config.forced_eos_token_id = None
-            model.generation_config.encoder_no_repeat_ngram_size = 0
-            model.generation_config.use_cache = True
 
             # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
             outputs = model(**inputs)
             if "past_key_values" not in outputs:
                 self.skipTest(reason="This model doesn't return `past_key_values`")
 
+            generate_kwargs = {
+                "pad_token_id": -1,
+                "eos_token_id": -1,
+                "forced_eos_token_id": None,
+                "encoder_no_repeat_ngram_size": 0,
+                "use_cache": True,
+                "do_sample": False,
+                "return_dict_in_generate": True,
+                "output_scores": True,
+            }
+
             # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
-            outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
+            outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4)
 
             # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
             # inputs may need to be tweaked across `generate` calls (like the attention mask).
-            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
+            outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
 
             # Continue from the tokens generated above, preparing the inputs accordingly
             inputs["past_key_values"] = outputs_cached.past_key_values
```

> **Contributor (author)**, on `"output_scores": True`: (same, the correct equivalence check needs the scores)
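Outside the test harness, the pattern this test exercises looks roughly as follows. A sketch under stated assumptions: the checkpoint name is illustrative, and any causal LM that returns a standard `past_key_values` cache should behave the same way.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "hf-internal-testing/tiny-random-gpt2"  # illustrative choice
tok = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
inputs = tok("An example prompt", return_tensors="pt")

# Single call: greedily generate 4 tokens and keep the cache.
full = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)

# Split call: 3 tokens first, then 1 more, resuming from the returned cache.
part = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
resumed = model.generate(
    input_ids=part.sequences,  # full sequence so far (prompt + 3 new tokens)
    attention_mask=torch.ones_like(part.sequences),  # mask must cover the new length
    past_key_values=part.past_key_values,
    do_sample=False,
    max_new_tokens=1,
    return_dict_in_generate=True,
)

# Usually identical; near-tied logits can flip a token either way, which is
# exactly the flakiness this PR addresses by comparing scores instead.
print(full.sequences.tolist() == resumed.sequences.tolist())
```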
```diff
@@ -1900,10 +1900,13 @@ def test_generate_continue_from_past_key_values(self):
                     mode="constant",
                     value=1,
                 )
-            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
+            first_caches_scores = outputs_cached.scores
+            outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
+            full_cached_scores = first_caches_scores + outputs_cached.scores
+            outputs_cached.scores = full_cached_scores
 
             # The two sets of generated text and past kv should be equal to each other
-            self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
+            self._check_similar_generate_outputs(outputs, outputs_cached)
             for layer_idx in range(len(outputs_cached.past_key_values)):
                 for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
                     self.assertTrue(
```
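The stitching above relies on `scores` being a plain tuple holding one `(batch_size, vocab_size)` tensor per newly generated token, covering only the tokens produced by that particular `generate` call. Tuple concatenation therefore realigns the split 3 + 1 run with the single 4-token run. A toy illustration with made-up shapes:

```python
import torch

# One score tensor per new token, per generate() call.
scores_first_call = (torch.randn(1, 10),) * 3   # 3-token call -> 3 entries
scores_second_call = (torch.randn(1, 10),)      # 1-token call -> 1 entry

stitched = scores_first_call + scores_second_call  # plain tuple concat
assert len(stitched) == 4  # now aligned with a single 4-token run
```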
```diff
@@ -1929,6 +1932,8 @@ def test_generate_continue_from_inputs_embeds(self):
 
             if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder")
+            # TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
+            # but it breaks a few models. Fix and then apply `_check_similar_generate_outputs` pattern
             if not hasattr(config, "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
 
```

> **Comment on lines +1935 to +1936**
>
> **Member:** prob it is same thing when we generate from […]
>
> **Contributor (author):** Yup 👍 I've pushed the actual fix into the future, in case the test is not (much) flaky
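For context on the TODO: composite models (e.g. vision-language models) keep generation-related flags such as `use_cache` on a nested text config, which `config.get_text_config()` resolves; for plain text models it returns the config itself. In the test's scope, the intended future check would read roughly:

```python
# Hypothetical sketch of the check the TODO describes; it assumes the
# surrounding test scope (`config`, `model_class`, `self`).
text_config = config.get_text_config()  # the top-level config for text-only models
if not hasattr(text_config, "use_cache"):
    self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
```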
```diff
@@ -1989,32 +1994,6 @@ def test_generate_continue_from_inputs_embeds(self):
                 )
             )
 
-    @parameterized.expand([("offloaded",)])  # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
-    @require_torch_accelerator
-    @pytest.mark.generate
-    def test_offloaded_cache_implementation(self, cache_implementation):
-        """Tests we can generate by indicating `cache_implementation` for each possible cache class"""
-        for model_class in self.all_generative_model_classes:
-            if not model_class._supports_cache_class:
-                self.skipTest(reason="This model does not support the new cache format")
-
-            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-
-            model = model_class(config).to(torch_device).eval()
-            generation_kwargs = {
-                "max_new_tokens": 5,
-                "use_cache": True,
-                "cache_implementation": cache_implementation,
-            }
-
-            legacy_results = model.generate(**generation_kwargs, **inputs_dict)
-
-            # Most cache classes have their own tests except for some that are tested here
-            # The ones here do not need special treatment when passing `cache_implementation`
-            # and are not bound to specific models only
-            new_results = model.generate(**generation_kwargs, **inputs_dict)
-            self.assertListEqual(legacy_results.tolist(), new_results.tolist())
-
     @pytest.mark.generate
     def test_generate_with_static_cache(self):
         """
```

> **Contributor (author):** this test doesn't make sense? It tests that two identical `generate` calls return the same output. We have tests for the offloaded cache in […]
>
> **Member:** Right, this test was supposed to only check that calling `generate` with the `cache_implementation` flag works. If the tests in […]
>
> **Contributor (author):** There are slow integration tests for all caches on […]
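For reference, the behavior the removed test smoke-checked can still be exercised directly through `generate`. A minimal sketch, assuming a CUDA device and an illustrative checkpoint (offloading only pays off on much larger models):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "gpt2"  # illustrative choice
tok = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")
inputs = tok("KV cache offloading trades speed for memory:", return_tensors="pt").to("cuda")

# "offloaded" keeps only the active layer's key/value tensors on the
# accelerator, moving the rest to CPU memory between layers.
out = model.generate(**inputs, max_new_tokens=5, use_cache=True, cache_implementation="offloaded")
print(tok.decode(out[0]))
```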