Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3034,6 +3034,8 @@ def _beam_search(
num_beams = beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if num_beams * batch_size != batch_beam_size:
Expand Down Expand Up @@ -3437,6 +3439,8 @@ def _beam_sample(
num_beams = beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

# init attention / hidden states / scores tuples
Expand Down Expand Up @@ -3795,6 +3799,8 @@ def _group_beam_search(
device = input_ids.device

batch_beam_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if return_dict_in_generate and output_scores:
Expand Down Expand Up @@ -4211,6 +4217,8 @@ def _constrained_beam_search(
num_beams = constrained_beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if num_beams * batch_size != batch_beam_size:
Expand Down
13 changes: 13 additions & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,19 @@ def test_beam_sample_generate(self):
)

self.assertTrue(output_generate.shape[-1] == max_length)
if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters):
input_embeds = model.get_input_embeddings()(input_ids)
beam_kwargs.update({"inputs_embeds": input_embeds})
output_generate2 = self._beam_sample_generate(
model=model,
input_ids=None,
attention_mask=attention_mask,
max_length=max_length,
beam_kwargs=beam_kwargs,
logits_warper_kwargs=logits_warper_kwargs,
)

torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)

def test_beam_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
Expand Down
4 changes: 4 additions & 0 deletions tests/models/biogpt/test_modeling_biogpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,10 @@ def test_biogpt_sequence_classification_model_for_multi_label(self):
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

@unittest.skip("The `input_embeds` when fed don't produce the same results.")
def test_beam_sample_generate(self):
pass


@require_torch
class BioGptModelIntegrationTest(unittest.TestCase):
Expand Down