Generating text with Llama 2 doesn't work when num_beams > 1 and only inputs_embeds is provided #29968

@umbertocappellazzo

Description

System Info

  • transformers version: 4.39.0
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.2
  • Accelerate version: 0.28.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: True
  • Using distributed or parallel set-up in script?: No

Who can help?

@gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hello guys!

For my project, I'm using Llama 2 as the LLM, and it accepts multimodal tokens (audio/video + prompts + text). When I generate text given the audio and prompt token embeddings, everything works fine with greedy decoding. For example, if my text_embeddings has shape [1, 10, 4096], where 10 is the number of tokens and 4096 is Llama 2's hidden size, I generate the output like this:

from transformers import LlamaForCausalLM
import torch

llm_name = "meta-llama/Llama-2-7b-hf"
llm = LlamaForCausalLM.from_pretrained(llm_name)

# Random embeddings standing in for the audio + prompt token embeddings: [batch, seq_len, hidden_size].
text_embeddings = torch.randn(1, 10, 4096)

# Greedy decoding (num_beams defaults to 1): this works as expected.
decoded_ids = llm.generate(inputs_embeds=text_embeddings, max_new_tokens=10)

However, if I want to use beam search with 5 beams, so I also pass num_beams=5 to generate, I get this error:

>>> decoded_ids = llm.generate(inputs_embeds = text_embeddings, max_new_tokens=10,num_beams=5)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1648, in generate
    result = self._beam_sample(
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 3402, in _beam_sample
    outputs = self(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1196, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 990, in forward
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1077, in _update_causal_mask
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
RuntimeError: The size of tensor a (10) must match the size of tensor b (0) at non-singleton dimension 0

What causes this error? Do I need to change something when using beam search? Maybe it comes from the fact that I'm passing inputs_embeds rather than input_ids to generate, and something must be adapted? Based on https://huggingface.co/docs/transformers/v4.38.2/en/generation_strategies#beam-search-decoding, it seems that adding num_beams=5 should be sufficient.
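
For reference, I believe the same beam-search settings go through fine when I use input_ids instead of inputs_embeds (minimal sanity-check sketch only; the tokenizer and prompt below are placeholders, not part of my actual multimodal pipeline):

from transformers import AutoTokenizer, LlamaForCausalLM

llm_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(llm_name)
llm = LlamaForCausalLM.from_pretrained(llm_name)

# Same generation settings, but starting from input_ids rather than inputs_embeds.
inputs = tokenizer("Hello, my name is", return_tensors="pt")
decoded_ids = llm.generate(**inputs, max_new_tokens=10, num_beams=5)
print(tokenizer.batch_decode(decoded_ids, skip_special_tokens=True))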

Thank you for your help!

Expected behavior

By setting num_beams = N > 1, I expect to switch from greedy decoding to beam search. With num_beams = 1 everything works fine, but with num_beams > 1 I get the error above. I've noticed some recent changes to cache_position and related attributes in modeling_llama; maybe those pull requests fix my error as well.
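
Concretely, I would expect these two calls to behave the same apart from the decoding strategy (this just restates the repro from above):

# Greedy decoding: works.
greedy_ids = llm.generate(inputs_embeds=text_embeddings, max_new_tokens=10)

# Beam search: should also work, but currently raises the RuntimeError shown above.
beam_ids = llm.generate(inputs_embeds=text_embeddings, max_new_tokens=10, num_beams=5)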
