System Info
- transformers version: 4.39.0
- Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
- Python version: 3.10.13
- Huggingface_hub version: 0.22.2
- Safetensors version: 0.4.2
- Accelerate version: 0.28.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.2.1 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: True
- Using distributed or parallel set-up in script?: No
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Hello guys!
For my project, I'm using Llama 2 as the LLM, which accepts multimodal tokens (audio/video + prompts + text). When I want to generate text given the audio and prompt token embeddings, everything works fine with greedy decoding. For example, if text_embeddings has shape [1, 10, 4096], where 10 is the number of tokens and 4096 is the hidden size of Llama 2, I generate the output like this:
from transformers import LlamaForCausalLM
import torch

llm_name = "meta-llama/Llama-2-7b-hf"
llm = LlamaForCausalLM.from_pretrained(llm_name)
text_embeddings = torch.randn(1, 10, 4096)
decoded_ids = llm.generate(inputs_embeds=text_embeddings, max_new_tokens=10)
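To inspect the decoded output I use the checkpoint's tokenizer; this is just standard usage and nothing specific to my setup (a minimal sketch):

# Decode the generated ids back to text (when only inputs_embeds is passed,
# generate returns just the newly generated tokens)
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(llm_name)
print(tokenizer.batch_decode(decoded_ids, skip_special_tokens=True))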
However, if I want to use beam search with 5 beams, so I also pass num_beams=5 to the generate call, I get this error:
>>> decoded_ids = llm.generate(inputs_embeds = text_embeddings, max_new_tokens=10,num_beams=5)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1648, in generate
result = self._beam_sample(
File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 3402, in _beam_sample
outputs = self(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1196, in forward
outputs = self.model(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 990, in forward
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1077, in _update_causal_mask
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
RuntimeError: The size of tensor a (10) must match the size of tensor b (0) at non-singleton dimension 0
What causes this error? Do I need to modify something when I want to use beam search? Maybe it is related to the fact that I'm passing inputs_embeds rather than input_ids to the generate method, and something must be adapted? Based on https://huggingface.co/docs/transformers/v4.38.2/en/generation_strategies#beam-search-decoding, it seems like adding num_beams=5 should be sufficient.
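To isolate the problem, a cross-check I would run (I have not verified its outcome; the random ids are only there to exercise the code path) is the same beam search call with input_ids instead of inputs_embeds:

# Hypothetical cross-check: does beam search succeed when the prompt is given
# as token ids rather than embeddings? Random ids, not meaningful text.
input_ids = torch.randint(0, llm.config.vocab_size, (1, 10))
decoded_ids = llm.generate(input_ids=input_ids, max_new_tokens=10, num_beams=5)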
Thank you for your help!
Expected behavior
By setting num_beams = N > 1, I expect to switch from greedy decoding to beam search. With num_beams = 1 everything works fine, but with num_beams > 1 I get the error above. I've noticed some recent changes to cache_position and related attributes in modeling_llama.py; maybe those pull requests fix this error as well.
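For convenience, here is the whole reproduction as one self-contained script (random embeddings stand in for my multimodal prompt; no audio/video front-end is needed to trigger the error):

# Minimal repro: greedy decoding works, beam search raises the RuntimeError above
import torch
from transformers import LlamaForCausalLM

llm = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
text_embeddings = torch.randn(1, 10, 4096)

# Works
llm.generate(inputs_embeds=text_embeddings, max_new_tokens=10)

# Fails in _update_causal_mask (see traceback above)
llm.generate(inputs_embeds=text_embeddings, max_new_tokens=10, num_beams=5)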