
Conversation

@ArthurZucker (Collaborator) commented Apr 1, 2024

What does this PR do?

A bug, largely unrelated to cache positions, was introduced by #29467.
This fixes #29968.

cc @gante and @zucchini-nlp. The test suite is missing this particular test for all generation strategies.
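The regression boils down to how `generate` computes the current sequence length when the caller passes only `inputs_embeds` and no `input_ids` (the commit log below mentions getting "the cur len" another way instead of from the attention mask). A minimal sketch of the idea, not the actual transformers code; the `Tensor` stand-in and `current_length` helper are hypothetical:

```python
# Hedged sketch: derive the prompt length from whichever input was provided.
# `Tensor` is a hypothetical stand-in for a framework tensor; only `.shape`
# matters here. Not the actual transformers implementation.
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Tensor:
    shape: Tuple[int, ...]  # (batch, seq_len) or (batch, seq_len, hidden)


def current_length(input_ids: Optional[Tensor],
                   inputs_embeds: Optional[Tensor]) -> int:
    """Return the prompt length regardless of which input was passed."""
    if input_ids is not None:
        return input_ids.shape[1]
    if inputs_embeds is not None:
        # the sequence dimension comes before the hidden dimension
        return inputs_embeds.shape[1]
    raise ValueError("either input_ids or inputs_embeds must be provided")


print(current_length(None, Tensor(shape=(2, 7, 4096))))  # 7
```

The point is that the length must fall back to the embeddings' sequence dimension when `input_ids` is absent, rather than assuming an attention mask or token ids always exist.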

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante (Member) left a comment:

Thank you for the fix 👍

@ArthurZucker ArthurZucker marked this pull request as ready for review April 1, 2024 09:51
Comment on lines 721 to 732
input_embeds = model.get_input_embeddings()(input_ids)
beam_kwargs.update({"inputs_embeds": input_embeds})
output_generate2 = self._beam_sample_generate(
model=model,
input_ids=None,
attention_mask=attention_mask,
max_length=max_length,
beam_kwargs=beam_kwargs,
logits_warper_kwargs=logits_warper_kwargs,
)

torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
@gante (Member):

This can't be tested in the mixin -- the vast majority of the models don't support passing inputs_embeds to generate; they would need some changes in prepare_inputs_for_generation

@ArthurZucker (Collaborator, Author):

Alright, I'll check the signature.
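Checking the signature could look something like the sketch below: inspect `prepare_inputs_for_generation` and skip the test for models that don't accept `inputs_embeds`. `prepare_inputs_for_generation` is the real transformers hook; the `supports_inputs_embeds` helper is a hypothetical illustration, not code from this PR:

```python
# Hedged sketch: skip a mixin test for models whose generation input-prep
# does not accept `inputs_embeds`. The helper name is hypothetical.
import inspect


def supports_inputs_embeds(model) -> bool:
    """True if `prepare_inputs_for_generation` accepts an `inputs_embeds` kwarg."""
    sig = inspect.signature(model.prepare_inputs_for_generation)
    return "inputs_embeds" in sig.parameters


# Inside a test, one might then write:
# if not supports_inputs_embeds(model):
#     self.skipTest("model does not support inputs_embeds in generate")
```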

@ArthurZucker ArthurZucker merged commit 83b26dd into main Apr 2, 2024
@ArthurZucker ArthurZucker deleted the fix-regression-generate branch April 2, 2024 07:51
@ArthurZucker (Collaborator, Author):

Failing test is unrelated

ArthurZucker added a commit that referenced this pull request Apr 2, 2024
* fix bug and add tests

* nit

* otherway to get the cur len instead of attention mask

* more places where this might have been broken

* nit

* oups

* inputs_embeds vs input_embeds

* test generated outptus

* style

* nit

* fix

* skip failing biogpt


Development

Successfully merging this pull request may close these issues.

Generating text with Llama 2 doesn't work when num_beams > 1 and only inputs_embeds is provided
