203 changes: 119 additions & 84 deletions src/transformers/generation/utils.py
@@ -359,7 +359,7 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[int] = None,
bos_token_id: Optional[torch.Tensor] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
"""
@@ -423,7 +423,7 @@ def _prepare_model_inputs(
def _maybe_initialize_input_ids_for_generation(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[int] = None,
bos_token_id: Optional[torch.Tensor] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.LongTensor:
"""Initializes input ids for generation, if necessary."""
@@ -454,20 +454,29 @@ def _maybe_initialize_input_ids_for_generation(
def _prepare_attention_mask_for_generation(
self,
inputs: torch.Tensor,
pad_token_id: Optional[int],
eos_token_id: Optional[Union[int, List[int]]],
pad_token_id: Optional[torch.Tensor],
eos_token_id: Optional[torch.Tensor],
) -> torch.LongTensor:
# No information for attention mask inference -> return default attention mask
default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
if pad_token_id is None:
return default_attention_mask

# Otherwise we may have information -> try to infer the attention mask
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
is_pad_token_in_inputs = (pad_token_id is not None) and (
torch.isin(elements=inputs, test_elements=pad_token_id).any()
)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
)

# Check if input is input_ids and padded -> only then is attention_mask defined
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
return inputs.ne(pad_token_id).long()
else:
return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
can_infer_attention_mask = is_input_ids * is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
attention_mask_from_padding = inputs.ne(pad_token_id).long()
attention_mask = (
attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
)
return attention_mask

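The rewritten mask inference above replaces Python `if`/`else` on tensor values, which `torch.compile` treats as data-dependent control flow, with `torch.isin` checks and arithmetic selection. A minimal standalone sketch of that pattern (the inputs and token ids below are made up for illustration):

import torch

# Made-up example: a right-padded batch plus tensor-valued special tokens.
input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
pad_token_id = torch.tensor(0)
eos_token_id = torch.tensor([2])

is_input_ids = input_ids.ndim == 2 and input_ids.dtype in (torch.int, torch.long)
is_pad_token_in_inputs = torch.isin(input_ids, pad_token_id).any()
is_pad_not_eos = ~torch.isin(eos_token_id, pad_token_id).any()

# The flags stay tensors, so no data-dependent Python branching is needed.
can_infer_attention_mask = is_input_ids * is_pad_token_in_inputs * is_pad_not_eos
default_attention_mask = torch.ones_like(input_ids)
attention_mask_from_padding = input_ids.ne(pad_token_id).long()
attention_mask = (
    attention_mask_from_padding * can_infer_attention_mask
    + default_attention_mask * ~can_infer_attention_mask
)
print(attention_mask)  # tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
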
def _prepare_encoder_decoder_kwargs_for_generation(
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
@@ -509,8 +518,7 @@ def _prepare_decoder_input_ids_for_generation(
batch_size: int,
model_input_name: str,
model_kwargs: Dict[str, torch.Tensor],
decoder_start_token_id: Union[int, List[int]] = None,
bos_token_id: int = None,
decoder_start_token_id: Optional[torch.Tensor] = None,
device: torch.device = None,
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
@@ -524,20 +532,14 @@ def _prepare_decoder_input_ids_for_generation(
decoder_input_ids = None

# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
if device is None:
device = self.device
if isinstance(decoder_start_token_id, list):
if len(decoder_start_token_id) != batch_size:
if decoder_start_token_id.ndim == 1:
if decoder_start_token_id.shape[0] != batch_size:
raise ValueError(
f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}"
)
decoder_input_ids_start = torch.tensor(decoder_start_token_id, dtype=torch.long, device=device)
decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
else:
decoder_input_ids_start = (
torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
)
decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)

# no user input -> use decoder_start_token_id as decoder_input_ids
if decoder_input_ids is None:
@@ -568,7 +570,7 @@ def _prepare_decoder_input_ids_for_generation(
return decoder_input_ids, model_kwargs

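Since `decoder_start_token_id` now arrives as a tensor, the branch above only distinguishes a per-batch 1-D tensor (one start id per row) from a scalar that is broadcast across the batch. A small sketch of the two shapes, with made-up values:

import torch

batch_size = 3

# Scalar start token (0-dim tensor): broadcast to a (batch_size, 1) column.
scalar_start = torch.tensor(0, dtype=torch.long)
column_from_scalar = torch.ones((batch_size, 1), dtype=torch.long) * scalar_start
print(column_from_scalar.shape)  # torch.Size([3, 1])

# Per-example start tokens (1-D tensor): must carry one id per row of the batch.
per_batch_start = torch.tensor([0, 2, 2], dtype=torch.long)
assert per_batch_start.shape[0] == batch_size
print(per_batch_start.view(-1, 1).shape)  # torch.Size([3, 1])
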
def _get_decoder_start_token_id(
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
self, decoder_start_token_id: Optional[torch.Tensor] = None, bos_token_id: Optional[torch.Tensor] = None
) -> int:
decoder_start_token_id = (
decoder_start_token_id
@@ -1207,9 +1209,7 @@ def _prepare_generation_config(
# will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled.
if is_torchdynamo_compiling():
model_kwargs = kwargs
generate_attributes_in_kwargs = [
key for key, value in kwargs.items() if getattr(generation_config, key, None) != value
]
generate_attributes_in_kwargs = [key for key in kwargs.keys() if hasattr(generation_config, key)]
if len(generate_attributes_in_kwargs) > 0:
raise ValueError(
"`torch.compile` exception: all generation configuration attributes must be passed within a "
Expand All @@ -1221,6 +1221,38 @@ def _prepare_generation_config(

return generation_config, model_kwargs

def _prepare_special_tokens(
self, generation_config: GenerationConfig, kwargs_has_attention_mask: bool
) -> Tuple[Optional[torch.Tensor]]:
"""Prepares the special tokens for generation."""

# Convert special tokens to tensors (if they exist)
def _tensor_or_none(token):
return torch.tensor(token, device=self.device, dtype=torch.long) if token is not None else None

bos_token_id = _tensor_or_none(generation_config.bos_token_id)
eos_token_id = _tensor_or_none(generation_config.eos_token_id)
pad_token_id = _tensor_or_none(generation_config.pad_token_id)
decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id) or bos_token_id

if self.config.is_encoder_decoder and decoder_start_token_id is None:
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)

# Set pad token if unset (and there are conditions to do so)
if pad_token_id is None and eos_token_id is not None:
if not kwargs_has_attention_mask:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
pad_token_id = eos_token_id
if pad_token_id.ndim == 1:
pad_token_id = pad_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")

return bos_token_id, eos_token_id, pad_token_id, decoder_start_token_id

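Keeping the special tokens as long tensors on `self.device`, rather than Python ints, lets the downstream comparisons (`torch.isin`, `Tensor.ne`) run on-device and helps `torch.compile` treat them as graph inputs instead of baked-in constants. A rough sketch of the helper's behaviour, with made-up ids standing in for the `generation_config` attributes:

import torch

def tensor_or_none(token, device="cpu"):
    # Hypothetical stand-in for the helper above: scalar or list -> long tensor.
    return torch.tensor(token, device=device, dtype=torch.long) if token is not None else None

# Made-up ids standing in for generation_config attributes.
bos_token_id = tensor_or_none(1)
eos_token_id = tensor_or_none([2, 32000])  # a list of eos ids becomes a 1-D tensor
pad_token_id = tensor_or_none(None)        # unset pad token stays None

if pad_token_id is None and eos_token_id is not None:
    # Same fallback as above: reuse (the first) eos token as the pad token.
    pad_token_id = eos_token_id[0] if eos_token_id.ndim == 1 else eos_token_id

print(pad_token_id)  # tensor(2)
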
@torch.no_grad()
def generate(
self,
Expand Down Expand Up @@ -1333,63 +1365,49 @@ def generate(
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
bos_token_id, eos_token_id, pad_token_id, decoder_start_token_id = self._prepare_special_tokens(
generation_config, kwargs_has_attention_mask
)

# 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
batch_size = inputs_tensor.shape[0]

# decoder-only models must use left-padding for generation.
if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: If using `inputs_embeds`, this check does not work, because we want to be more hands-off.
if (
pad_token_id is not None
and len(inputs_tensor.shape) == 2
and torch.sum(inputs_tensor[:, -1] == pad_token_id) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)

# 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states

# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
# generating the first new token or not, and we only want to use the embeddings for the first new token)
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
model_kwargs["use_cache"] = True
else:
model_kwargs["use_cache"] = generation_config.use_cache

accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs

if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
inputs_tensor, pad_token_id, eos_token_id
)

# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: If using `inputs_embeds`, this check does not work, because we want to be more hands-off.
if (
generation_config.pad_token_id is not None
and len(inputs_tensor.shape) == 2
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)

if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
# if model is encoder decoder encoder_outputs are created
# and added to `model_kwargs`
# if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name
)
@@ -1400,8 +1418,7 @@ def generate(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
decoder_start_token_id=decoder_start_token_id,
device=inputs_tensor.device,
)
else:
Expand Down Expand Up @@ -1446,7 +1463,8 @@ def generate(
)
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
if not is_torchdynamo_compiling():
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

# 7. determine generation mode
generation_mode = generation_config.get_generation_mode(assistant_model)
Expand Down Expand Up @@ -1483,6 +1501,7 @@ def generate(
prepared_stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)

# 10. go into different generation modes
if generation_mode == GenerationMode.ASSISTED_GENERATION:
if generation_config.num_return_sequences > 1:
Expand Down Expand Up @@ -1778,23 +1797,30 @@ def typeerror():

return result

def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
def _has_unfinished_sequences(
self, this_peer_finished: bool, cur_len, max_length, synced_gpus: bool, device: torch.device
) -> bool:
"""
Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
fed through `this_peer_finished`. ZeRO stage 3-friendly.
"""
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
# torch.compile does not support data-dependent control flow. This is a workaround to allow torch.compile,
# although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria)
if is_torchdynamo_compiling():
return cur_len < max_length
else:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
return False
elif this_peer_finished:
return False
elif this_peer_finished:
return False
return True
return True

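Under compilation the loop condition collapses to a plain length bound, because branching on `this_peer_finished` (a value derived from tensors) would break the graph; the trade-off is that the compiled loop always runs to `max_length` even if every sequence already emitted EOS. A toy sketch of the two behaviours, with illustrative names and values:

def has_unfinished_sequences(this_peer_finished: bool, cur_len: int, max_length: int, compiling: bool) -> bool:
    # Illustrative stand-in for the method above (single-process case, no synced_gpus).
    if compiling:
        # Compiled path: compare plain integers only, so the decode loop traces
        # cleanly, but it always runs until max_length is reached.
        return cur_len < max_length
    # Eager path: stop as soon as every sequence has finished.
    return not this_peer_finished

# All sequences already hit EOS at step 3 of 5:
print(has_unfinished_sequences(True, cur_len=3, max_length=5, compiling=True))   # True  (keeps decoding)
print(has_unfinished_sequences(True, cur_len=3, max_length=5, compiling=False))  # False (stops early)
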
def contrastive_search(self, *args, **kwargs):
logger.warning_once(
Expand Down Expand Up @@ -2403,7 +2429,15 @@ def _greedy_search(
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
max_length = None
for criteria in stopping_criteria:
if isinstance(criteria, MaxLengthCriteria):
max_length = criteria.max_length
break

while self._has_unfinished_sequences(
this_peer_finished, cur_len, max_length, synced_gpus, device=input_ids.device
):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

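The probe above pulls `max_length` out of the stopping criteria once, before the loop, so the compiled while-condition only compares plain integers. A short standalone illustration of that probe (the criteria list is made up):

from transformers import MaxLengthCriteria, StoppingCriteriaList

# Made-up criteria list; in generate() it is built by _get_stopping_criteria.
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

max_length = None
for criteria in stopping_criteria:
    if isinstance(criteria, MaxLengthCriteria):
        max_length = criteria.max_length
        break

print(max_length)  # 20
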
@@ -2470,6 +2504,7 @@ def _greedy_search(

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
cur_len += 1

if streamer is not None:
streamer.end()