83 changes: 31 additions & 52 deletions tests/generation/test_utils.py
@@ -1096,25 +1096,18 @@ def test_contrastive_generate_low_memory(self):

# test output equality of low versus high memory
model = model_class(config).to(torch_device).eval()
generate_kwargs = {
    "top_k": 4,
    "penalty_alpha": 0.6,
    "max_new_tokens": self.max_new_tokens,
    "use_cache": True,
    "return_dict_in_generate": True,
    "output_scores": True,
}

Contributor Author: (the correct equivalence check needs the scores)

low_output = model.generate(
    top_k=4,
    penalty_alpha=0.6,
    low_memory=True,
    max_new_tokens=self.max_new_tokens,
    **inputs_dict,
    use_cache=True,
)

high_output = model.generate(
    top_k=4,
    penalty_alpha=0.6,
    low_memory=False,
    max_new_tokens=self.max_new_tokens,
    **inputs_dict,
    use_cache=True,
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
low_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=True)
high_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=False)
self._check_similar_generate_outputs(low_output, high_output)
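For context, a minimal sketch of what a score-aware equivalence check could look like is below. This is a hypothetical stand-in, not the actual `_check_similar_generate_outputs` helper used above, which may be implemented differently:

```python
import torch


def assert_similar_generate_outputs(out_a, out_b, atol=1e-3, rtol=1e-3):
    """Hypothetical sketch: compare per-step scores numerically instead of requiring
    exact token-id equality, which is too strict when low-memory or cached code paths
    introduce tiny numerical differences."""
    assert len(out_a.scores) == len(out_b.scores)
    for step_a, step_b in zip(out_a.scores, out_b.scores):
        torch.testing.assert_close(step_a, step_b, atol=atol, rtol=rtol)
```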

@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@@ -1862,22 +1855,29 @@ def test_generate_continue_from_past_key_values(self):

model = model_class(config).to(torch_device)
model.eval()
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
model.generation_config.forced_eos_token_id = None
model.generation_config.encoder_no_repeat_ngram_size = 0
model.generation_config.use_cache = True

# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs)
if "past_key_values" not in outputs:
self.skipTest(reason="This model doesn't return `past_key_values`")

generate_kwargs = {
    "pad_token_id": -1,
    "eos_token_id": -1,
    "forced_eos_token_id": None,
    "encoder_no_repeat_ngram_size": 0,
    "use_cache": True,
    "do_sample": False,
    "return_dict_in_generate": True,
    "output_scores": True,
}

Contributor Author: (same, the correct equivalence check needs the scores)

# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4)

# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
# inputs may need to be tweaked across `generate` calls (like the attention mask).
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)

# Continue from the tokens generated above, preparing the inputs accordingly
inputs["past_key_values"] = outputs_cached.past_key_values
@@ -1900,10 +1900,13 @@
mode="constant",
value=1,
)
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
first_caches_scores = outputs_cached.scores
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
full_cached_scores = first_caches_scores + outputs_cached.scores
outputs_cached.scores = full_cached_scores

# The two sets of generated text and past kv should be equal to each other
self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
self._check_similar_generate_outputs(outputs, outputs_cached)
for layer_idx in range(len(outputs_cached.past_key_values)):
    for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
        self.assertTrue(
Expand All @@ -1929,6 +1932,8 @@ def test_generate_continue_from_inputs_embeds(self):

if config.is_encoder_decoder:
    self.skipTest(reason="This model is encoder-decoder")
# TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
# but it breaks a few models. Fix and then apply `_check_similar_generate_outputs` pattern
Comment on lines +1935 to +1936

Member: prob it is same thing when we generate from inputs_embeds in VLMs and have to pop image inputs

Contributor Author: Yup 👍 I've pushed the actual fix into the future, in case the test is not (much) flaky
if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

@@ -1989,32 +1994,6 @@ def test_generate_continue_from_inputs_embeds(self):
)
)

@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
@require_torch_accelerator
@pytest.mark.generate
def test_offloaded_cache_implementation(self, cache_implementation):
Contributor Author: this test doesn't make sense? tests that two generate calls with the same flags have the same output

We have tests for the offloaded cache in tests/utils/test_cache_utils.py

Member: Right, this test was supposed to only check that calling generate with all possible cache_implementation doesn't fail. There was some issue in the past, we lost some "cache_implementation" keys when refactoring 🥲

If the tests in cache_utils cover all caches and call generate with kwargs, I guess we're fine

Contributor Author (@gante, Apr 29, 2025): There are slow integration tests for all caches on main, this PR adds a fast test for all caches 🤗
"""Tests we can generate by indicating `cache_implementation` for each possible cache class"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_cache_class:
self.skipTest(reason="This model does not support the new cache format")

config, inputs_dict = self.prepare_config_and_inputs_for_generate()

model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_new_tokens": 5,
"use_cache": True,
"cache_implementation": cache_implementation,
}

legacy_results = model.generate(**generation_kwargs, **inputs_dict)

# Most cache classes have their own tests except for some that are tested here
# The ones here do not need special treatment when passing `cache_implementation`
# and are not bound to specific models only
new_results = model.generate(**generation_kwargs, **inputs_dict)
self.assertListEqual(legacy_results.tolist(), new_results.tolist())
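For context on the fast coverage discussed above, a minimal smoke-test sketch for the offloaded cache is below. It is hypothetical (not the test added by this PR) and assumes an accelerator is available, since the offloaded cache moves cached key/value tensors off the device between layers:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; the offloaded cache assumes a CUDA/accelerator device.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval().to("cuda")
inputs = tok("Offloaded cache smoke test", return_tensors="pt").to("cuda")

# Smoke test: we only care that the `cache_implementation` kwarg is honored and
# generation does not crash, not that the outputs match another configuration.
_ = model.generate(**inputs, max_new_tokens=5, use_cache=True, cache_implementation="offloaded")
```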

@pytest.mark.generate
def test_generate_with_static_cache(self):
"""