From 956c0b8f69ac7454e8b24eb736a1ae8a9f7fa974 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 2 May 2024 14:17:53 +0000 Subject: [PATCH 1/9] tmp commit --- src/transformers/generation/utils.py | 191 +++++++++++++++------------ 1 file changed, 105 insertions(+), 86 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7f4caf26aeac..f22743ab0a7f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -352,7 +352,7 @@ def prepare_inputs_for_generation(self, *args, **kwargs): def _prepare_model_inputs( self, inputs: Optional[torch.Tensor] = None, - bos_token_id: Optional[int] = None, + bos_token_id: Optional[torch.Tensor] = None, model_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: """ @@ -416,7 +416,7 @@ def _prepare_model_inputs( def _maybe_initialize_input_ids_for_generation( self, inputs: Optional[torch.Tensor] = None, - bos_token_id: Optional[int] = None, + bos_token_id: Optional[torch.Tensor] = None, model_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.LongTensor: """Initializes input ids for generation, if necessary.""" @@ -448,20 +448,31 @@ def _maybe_initialize_input_ids_for_generation( def _prepare_attention_mask_for_generation( self, inputs: torch.Tensor, - pad_token_id: Optional[int], - eos_token_id: Optional[Union[int, List[int]]], + pad_token_id: Optional[torch.Tensor], + eos_token_id: Optional[torch.Tensor], ) -> torch.LongTensor: + # No information for attention mask inference -> return default attention mask + default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) + if pad_token_id is None: + return default_attention_mask + is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] - is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) + if not is_input_ids: + return default_attention_mask - # Check if input is input_ids and padded -> only then is attention_mask defined - if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: - return inputs.ne(pad_token_id).long() - else: - return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) + # Otherwise we have may have information -> try to infer the attention mask + is_pad_token_in_inputs = (pad_token_id is not None) and ( + torch.isin(elements=inputs, test_elements=pad_token_id).any() + ) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~( + torch.isin(elements=eos_token_id, test_elements=pad_token_id).any() + ) + can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id + attention_mask_from_padding = inputs.ne(pad_token_id).long() + attention_mask = ( + attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask + ) + return attention_mask def _prepare_encoder_decoder_kwargs_for_generation( self, @@ -509,8 +520,7 @@ def _prepare_decoder_input_ids_for_generation( batch_size: int, model_input_name: str, model_kwargs: Dict[str, torch.Tensor], - decoder_start_token_id: Union[int, List[int]] = None, - bos_token_id: int = None, + decoder_start_token_id: torch.Tensor, device: torch.device = None, ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: 
"""Prepares `decoder_input_ids` for generation with encoder-decoder models""" @@ -523,25 +533,24 @@ def _prepare_decoder_input_ids_for_generation( else: decoder_input_ids = None - # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. - decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + # 2. `decoder_start_token_id` must have shape (batch_size, 1) if device is None: device = self.device - if isinstance(decoder_start_token_id, list): - if len(decoder_start_token_id) != batch_size: + if decoder_start_token_id.ndim == 1: + if decoder_start_token_id.shape[0] != batch_size: raise ValueError( - f"`decoder_start_token_id` expcted to have length {batch_size} but got {len(decoder_start_token_id)}" + f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" ) - decoder_input_ids_start = torch.tensor(decoder_start_token_id, dtype=torch.long, device=device) - decoder_input_ids_start = decoder_input_ids_start.view(-1, 1) + decoder_start_token_id = decoder_start_token_id.view(-1, 1) else: - decoder_input_ids_start = ( + decoder_start_token_id = ( torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id ) + # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. # no user input -> use decoder_start_token_id as decoder_input_ids if decoder_input_ids is None: - decoder_input_ids = decoder_input_ids_start + decoder_input_ids = decoder_start_token_id # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): pass @@ -549,14 +558,8 @@ def _prepare_decoder_input_ids_for_generation( pass # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust # decoder_attention_mask if provided) - elif ( - isinstance(decoder_start_token_id, int) - and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item() - ) or ( - isinstance(decoder_start_token_id, torch.Tensor) - and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item() - ): - decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item(): + decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1) if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] decoder_attention_mask = torch.cat( @@ -567,24 +570,6 @@ def _prepare_decoder_input_ids_for_generation( return decoder_input_ids, model_kwargs - def _get_decoder_start_token_id( - self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None - ) -> int: - decoder_start_token_id = ( - decoder_start_token_id - if decoder_start_token_id is not None - else self.generation_config.decoder_start_token_id - ) - bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id - - if decoder_start_token_id is not None: - return decoder_start_token_id - elif bos_token_id is not None: - return bos_token_id - raise ValueError( - "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." 
- ) - @staticmethod def _expand_inputs_for_generation( expand_size: int = 1, @@ -728,6 +713,8 @@ def _get_logits_warper( if generation_config.num_beams > 1: if isinstance(generation_config.eos_token_id, list): min_tokens_to_keep = len(generation_config.eos_token_id) + 1 + elif isinstance(generation_config.eos_token_id, torch.Tensor): + min_tokens_to_keep = generation_config.eos_token_id.shape[0] + 1 else: min_tokens_to_keep = 2 else: @@ -1342,6 +1329,54 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCa self._static_cache.reset() # reset the cache for a new generation return self._static_cache + def _prepare_special_tokens( + self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool] = None + ) -> Tuple[Optional[torch.Tensor]]: + """ + Prepares the special tokens for generation, overwriting the generation config with their processed versions + converted to tensor. + Note that `generation_config` is changed in place and stops being serializable after this method is called. + If called outside `generate`, consider creating a copy of `generation_config` first. + """ + + # Convert special tokens to tensors (if they exist) + def _tensor_or_none(token): + return torch.tensor(token, device=self.device, dtype=torch.long) if token is not None else None + + bos_token_id = _tensor_or_none(generation_config.bos_token_id) + eos_token_id = _tensor_or_none(generation_config.eos_token_id) + pad_token_id = _tensor_or_none(generation_config.pad_token_id) + decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id) + decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id + + # We can have more than one eos token. Always treat it as a 1D tensor (when it exists). + if eos_token_id is not None and eos_token_id.ndim == 0: + eos_token_id = eos_token_id.unsqueeze(0) + + # Set pad token if unset (and there are conditions to do so) + if pad_token_id is None and eos_token_id is not None: + if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + pad_token_id = eos_token_id[0] + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.") + + # Sanity checks + if self.config.is_encoder_decoder and decoder_start_token_id is None: + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) + + # Update generation config with the updated special tokens + generation_config.bos_token_id = bos_token_id + generation_config.eos_token_id = eos_token_id + generation_config.pad_token_id = pad_token_id + generation_config.decoder_start_token_id = decoder_start_token_id + + return generation_config + @torch.no_grad() def generate( self, @@ -1456,46 +1491,19 @@ def generate( logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - if model_kwargs.get("attention_mask", None) is None: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. 
Please pass your input's `attention_mask` to obtain reliable results." - ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - generation_config.pad_token_id = eos_token_id + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + generation_config = self._prepare_special_tokens(generation_config, kwargs_has_attention_mask) # 3. Define model inputs - # inputs_tensor has to be defined - # model_input_name is defined if model-specific keyword input is passed - # otherwise model_input_name is None - # all model-specific keyword inputs are removed from `model_kwargs` inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] - # 4. Define other model kwargs - # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are - # generating the first new token or not, and we only want to use the embeddings for the first new token) - if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": - model_kwargs["use_cache"] = True - else: - model_kwargs["use_cache"] = generation_config.use_cache - - accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) - requires_attention_mask = "encoder_outputs" not in model_kwargs - - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: - model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id - ) - - # decoder-only models should use left-padding for generation - if not self.config.is_encoder_decoder: + # decoder-only models must use left-padding for generation. + if not self.config.is_encoder_decoder and not is_torchdynamo_compiling(): # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. if ( @@ -1508,9 +1516,21 @@ def generate( "generation results, please set `padding_side='left'` when initializing the tokenizer." ) + # 4. 
Define other model kwargs + # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are + # generating the first new token or not, and we only want to use the embeddings for the first new token) + if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": + model_kwargs["use_cache"] = True + else: + model_kwargs["use_cache"] = generation_config.use_cache + + if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id + ) + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: - # if model is encoder decoder encoder_outputs are created - # and added to `model_kwargs` + # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( inputs_tensor, model_kwargs, model_input_name, generation_config ) @@ -1522,7 +1542,6 @@ def generate( model_input_name=model_input_name, model_kwargs=model_kwargs, decoder_start_token_id=generation_config.decoder_start_token_id, - bos_token_id=generation_config.bos_token_id, device=inputs_tensor.device, ) else: From b5bebe98c4ea5ff246237b3fdb806721e5587099 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 2 May 2024 15:29:46 +0000 Subject: [PATCH 2/9] [test_all] mvp --- src/transformers/generation/logits_process.py | 141 +++++++++++------- .../generation/stopping_criteria.py | 19 ++- src/transformers/generation/utils.py | 8 +- .../models/musicgen/modeling_musicgen.py | 20 ++- tests/generation/test_utils.py | 5 +- .../test_modeling_seamless_m4t.py | 4 +- .../test_modeling_seamless_m4t_v2.py | 4 +- .../test_modeling_speech_to_text.py | 4 +- tests/models/whisper/test_modeling_whisper.py | 6 +- 9 files changed, 138 insertions(+), 73 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index ce91e8a40a4e..89e69cb24643 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -108,8 +108,8 @@ class MinLengthLogitsProcessor(LogitsProcessor): Args: min_length (`int`): The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`. - eos_token_id (`Union[int, List[int]]`): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + eos_token_id (`Union[int, List[int], torch.Tensor]`): + The id(s) of the *end-of-sequence* token. 
Examples: @@ -137,14 +137,17 @@ class MinLengthLogitsProcessor(LogitsProcessor): ``` """ - def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]): + def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]): if not isinstance(min_length, int) or min_length < 0: raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}") - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id): - logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") + if not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) + + if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): + logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") self.min_length = min_length self.eos_token_id = eos_token_id @@ -152,8 +155,8 @@ def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - eos_token_id = torch.tensor(self.eos_token_id, device=scores.device) - eos_token_mask = torch.isin(vocab_tensor, eos_token_id) + self.eos_token_id = self.eos_token_id.to(scores.device) + eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) scores_processed = scores.clone() if input_ids.shape[-1] < self.min_length: scores_processed = torch.where(eos_token_mask, -math.inf, scores) @@ -171,8 +174,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): input length. min_new_tokens (`int`): The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`. - eos_token_id (`Union[int, List[int]]`): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + eos_token_id (`Union[int, List[int], torch.Tensor]`): + The id(s) of the *end-of-sequence* token. 
Examples: @@ -195,7 +198,9 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): ``` """ - def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]): + def __init__( + self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor] + ): for arg_name, arg_value in [ ("prompt_length_to_skip", prompt_length_to_skip), ("min_new_tokens", min_new_tokens), @@ -203,10 +208,13 @@ def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id if not isinstance(arg_value, int) or arg_value < 0: raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}") - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id): - logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") + if not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) + + if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): + logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") self.prompt_length_to_skip = prompt_length_to_skip self.min_new_tokens = min_new_tokens @@ -217,8 +225,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip scores_processed = scores.clone() vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - eos_token_id = torch.tensor(self.eos_token_id, device=scores.device) - eos_token_mask = torch.isin(vocab_tensor, eos_token_id) + self.eos_token_id = self.eos_token_id.to(scores.device) + eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) if new_tokens_length < self.min_new_tokens: scores_processed = torch.where(eos_token_mask, -math.inf, scores) @@ -1118,8 +1126,8 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor): Args: bad_words_ids (`List[List[int]]`): List of list of token ids that are not allowed to be generated. - eos_token_id (`Union[int, List[int]]`): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + eos_token_id (`Union[int, List[int], torch.Tensor]`, *optional*): + The id(s) of the *end-of-sequence* token. 
Examples: @@ -1156,18 +1164,22 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor): ``` """ - def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]): + def __init__( + self, bad_words_ids: List[List[int]], eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None + ): self.bad_word_ids = bad_words_ids self._validate_arguments() # Filter EOS token from bad_words_ids if eos_token_id is None: - eos_token_id = [] - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - bad_words_ids = list( - filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids) - ) + if not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) + + bad_words_ids = list( + filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids) + ) # Forbidding a sequence is equivalent to setting its bias to -inf sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids} @@ -1445,9 +1457,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): Args: max_length (`int`): The maximum length of the sequence to be generated. - eos_token_id (`Union[int, List[int]]`): - The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a - list to set multiple *end-of-sequence* tokens. + eos_token_id (`Union[int, List[int], torch.Tensor]`): + The id(s) of the *end-of-sequence* token. Examples: @@ -1471,15 +1482,22 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): ``` """ - def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]): + def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]): self.max_length = max_length - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] + + if not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) self.eos_token_id = eos_token_id + if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): + logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: cur_len = input_ids.shape[-1] + self.eos_token_id = self.eos_token_id.to(scores.device) scores_processed = scores if cur_len == self.max_length - 1: scores_processed = torch.full_like(scores, -math.inf) @@ -1518,8 +1536,8 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): exponential_decay_length_penalty (`tuple(int, float)`): This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay - eos_token_id (`Union[int, List[int]]`): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + eos_token_id (`Union[int, List[int], torch.Tensor]`): + The id(s) of the *end-of-sequence* token. input_ids_seq_length (`int`): The length of the input sequence. 
@@ -1579,27 +1597,33 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): def __init__( self, exponential_decay_length_penalty: Tuple[int, float], - eos_token_id: Union[int, List[int]], + eos_token_id: Union[int, List[int], torch.Tensor], input_ids_seq_length: int, ): self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length self.regulation_factor = exponential_decay_length_penalty[1] - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] + + if not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) self.eos_token_id = eos_token_id + if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): + logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: cur_len = input_ids.shape[-1] + self.eos_token_id = self.eos_token_id.to(scores.device) penalties = torch.zeros_like(scores) scores_processed = scores if cur_len > self.regulation_start: - for i in self.eos_token_id: - penalty_idx = cur_len - self.regulation_start - # To support negative logits we compute the penalty of the absolute value and add to the original logit - penalty = torch.abs(scores[:, i]) * (pow(self.regulation_factor, penalty_idx) - 1) - penalties[:, i] = penalty - scores_processed = scores + penalties + penalty_idx = cur_len - self.regulation_start + # To support negative logits we compute the penalty of the absolute value and add to the original logit + penalty = torch.abs(scores[:, self.eos_token_id]) * (pow(self.regulation_factor, penalty_idx) - 1) + penalties[:, self.eos_token_id] = penalty + scores_processed = scores + penalties return scores_processed @@ -1676,7 +1700,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): """ def __init__(self, begin_suppress_tokens, begin_index): - self.begin_suppress_tokens = list(begin_suppress_tokens) + self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens)) self.begin_index = begin_index def set_begin_index(self, begin_index): @@ -1685,8 +1709,8 @@ def set_begin_index(self, begin_index): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - begin_suppress_tokens = torch.tensor(self.begin_suppress_tokens, device=scores.device) - suppress_token_mask = torch.isin(vocab_tensor, begin_suppress_tokens) + self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device) + suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens) scores_processed = scores if input_ids.shape[-1] == self.begin_index: scores_processed = torch.where(suppress_token_mask, -float("inf"), scores) @@ -1724,13 +1748,13 @@ class SuppressTokensLogitsProcessor(LogitsProcessor): """ def __init__(self, suppress_tokens): - self.suppress_tokens = list(suppress_tokens) + self.suppress_tokens = torch.tensor(list(suppress_tokens)) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - suppress_tokens = torch.tensor(self.suppress_tokens, device=scores.device) - suppress_token_mask = 
torch.isin(vocab_tensor, suppress_tokens) + self.suppress_tokens = self.suppress_tokens.to(scores.device) + suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens) scores = torch.where(suppress_token_mask, -float("inf"), scores) return scores @@ -2191,16 +2215,21 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor): Args: - eos_token_id (`Union[int, List[int]]`): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + eos_token_id (`Union[int, List[int], torch.Tensor]`): + The id(s) of the *end-of-sequence* token. min_eos_p (`float`, *optional*): Minimum end of speech threshold. """ - def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float): - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - self.eos_token_id = eos_token_id + def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float): + if not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) + + if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): + logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") + if min_eos_p is not None and min_eos_p <= 0: raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}") self.min_eos_p = min_eos_p diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 48392400c4f4..5fd1c2f773ab 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -470,29 +470,32 @@ class EosTokenCriteria(StoppingCriteria): By default, it uses the `model.generation_config.eos_token_id`. Args: - eos_token_id (`Union[int, List[int]]`): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + eos_token_id (`Union[int, List[int], torch.Tensor]`): + The id(s) of the *end-of-sequence* token. 
""" - def __init__(self, eos_token_id: Union[int, List[int]]): - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - self.eos_token_id = torch.tensor(eos_token_id) + def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor]): + if not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) + self.eos_token_id = eos_token_id @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: + self.eos_token_id = self.eos_token_id.to(input_ids.device) if input_ids.device.type == "mps": # https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075 is_done = ( input_ids[:, -1] .tile(self.eos_token_id.shape[0], 1) - .eq(self.eos_token_id.unsqueeze(1).to(input_ids.device)) + .eq(self.eos_token_id.unsqueeze(1)) .sum(dim=0) .bool() .squeeze() ) else: - is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device)) + is_done = torch.isin(input_ids[:, -1], self.eos_token_id) return is_done diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f22743ab0a7f..808741db05df 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -461,6 +461,12 @@ def _prepare_attention_mask_for_generation( return default_attention_mask # Otherwise we have may have information -> try to infer the attention mask + if inputs.device.type == "mps": + # mps does not support torch.isin (https://github.com/pytorch/pytorch/issues/77764) + raise ValueError( + "Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device." + ) + is_pad_token_in_inputs = (pad_token_id is not None) and ( torch.isin(elements=inputs, test_elements=pad_token_id).any() ) @@ -533,7 +539,7 @@ def _prepare_decoder_input_ids_for_generation( else: decoder_input_ids = None - # 2. `decoder_start_token_id` must have shape (batch_size, 1) + # 2. 
`decoder_start_token_id` must have shape (batch_size, 1) if device is None: device = self.device if decoder_start_token_id.ndim == 1: diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 08f42ce69e18..1bf1c0bfe89a 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -18,7 +18,7 @@ import math import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -2590,6 +2590,24 @@ def _maybe_initialize_input_ids_for_generation( break return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + def _get_decoder_start_token_id( + self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None + ) -> int: + decoder_start_token_id = ( + decoder_start_token_id + if decoder_start_token_id is not None + else self.generation_config.decoder_start_token_id + ) + bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id + + if decoder_start_token_id is not None: + return decoder_start_token_id + elif bos_token_id is not None: + return bos_token_id + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) + @torch.no_grad() def generate( self, diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index eacba9ebc6f4..fe5e94310e04 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -14,6 +14,7 @@ # limitations under the License. +import copy import inspect import tempfile import unittest @@ -168,7 +169,9 @@ def _get_encoder_outputs( encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave( num_interleave, dim=0 ) - input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id() + generation_config = copy.deepcopy(model.generation_config) + generation_config = model._prepare_special_tokens(generation_config) + input_ids = torch.zeros_like(input_ids[:, :1]) + generation_config.decoder_start_token_id attention_mask = None return encoder_outputs, input_ids, attention_mask diff --git a/tests/models/seamless_m4t/test_modeling_seamless_m4t.py b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py index d77aac6187a0..b286f2c61cd3 100644 --- a/tests/models/seamless_m4t/test_modeling_seamless_m4t.py +++ b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py @@ -414,9 +414,11 @@ def _get_encoder_outputs( encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave( num_interleave, dim=0 ) + generation_config = copy.deepcopy(model.generation_config) + generation_config = model._prepare_special_tokens(generation_config) input_ids = ( torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device) - + model._get_decoder_start_token_id() + + generation_config.decoder_start_token_id ) attention_mask = None return encoder_outputs, input_ids, attention_mask diff --git a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py index 301a1eb44ba6..47d8ff91164c 100644 --- a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py +++ b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py @@ -430,9 +430,11 
@@ def _get_encoder_outputs( encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave( num_interleave, dim=0 ) + generation_config = copy.deepcopy(model.generation_config) + generation_config = model._prepare_special_tokens(generation_config) input_ids = ( torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device) - + model._get_decoder_start_token_id() + + generation_config.decoder_start_token_id ) attention_mask = None return encoder_outputs, input_ids, attention_mask diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py index f3fc72ab8ed4..b5dbed1d3de8 100644 --- a/tests/models/speech_to_text/test_modeling_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py @@ -645,7 +645,9 @@ def _get_encoder_outputs( num_interleave, dim=0 ) input_ids = input_ids[:, :, 0] - input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + model._get_decoder_start_token_id() + generation_config = copy.deepcopy(model.generation_config) + generation_config = model._prepare_special_tokens(generation_config) + input_ids = torch.zeros_like(input_ids[:, :1]) + generation_config.decoder_start_token_id attention_mask = None return encoder_outputs, input_ids, attention_mask diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 44b6c1ea749e..0c45b18f0547 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -833,10 +833,10 @@ def _get_encoder_outputs( encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave( num_interleave, dim=0 ) + generation_config = copy.deepcopy(model.generation_config) + generation_config = model._prepare_special_tokens(generation_config) input_ids = input_ids[:, :, 0] - input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + torch.tensor( - [model._get_decoder_start_token_id()], device=input_ids.device - ) + input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + generation_config.decoder_start_token_id attention_mask = None return encoder_outputs, input_ids, attention_mask From 0f6ac580342c4ffed0ce7bd49c7af30f83ddd0ef Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 3 May 2024 10:08:29 +0000 Subject: [PATCH 3/9] missing not --- src/transformers/generation/logits_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 89e69cb24643..be8b8136f46c 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1171,7 +1171,7 @@ def __init__( self._validate_arguments() # Filter EOS token from bad_words_ids - if eos_token_id is None: + if eos_token_id is not None: if not isinstance(eos_token_id, torch.Tensor): if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] From 64a3c9908fe63e8dee73cc5dc5d47d43dbbeb488 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 3 May 2024 11:16:23 +0000 Subject: [PATCH 4/9] [test_all] final test fixes --- src/transformers/generation/beam_search.py | 40 +++++++++++-------- src/transformers/generation/logits_process.py | 2 + src/transformers/generation/utils.py | 26 +++++------- 3 files changed, 35 insertions(+), 33 deletions(-) diff --git a/src/transformers/generation/beam_search.py 
b/src/transformers/generation/beam_search.py index 5e73862e163d..7e4f13ad1e15 100644 --- a/src/transformers/generation/beam_search.py +++ b/src/transformers/generation/beam_search.py @@ -218,8 +218,8 @@ def process( next_scores: torch.FloatTensor, next_tokens: torch.LongTensor, next_indices: torch.LongTensor, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, + pad_token_id: Optional[Union[int, torch.Tensor]] = None, + eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None, beam_indices: Optional[torch.LongTensor] = None, group_index: Optional[int] = 0, decoder_prompt_len: Optional[int] = 0, @@ -245,8 +245,10 @@ def process( next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] + if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) for batch_idx in range(batch_size): batch_group_idx = batch_idx * self.num_beam_groups + group_index @@ -322,15 +324,17 @@ def finalize( final_beam_tokens: torch.LongTensor, final_beam_indices: torch.LongTensor, max_length: int, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, + pad_token_id: Optional[Union[int, torch.Tensor]] = None, + eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None, beam_indices: Optional[torch.LongTensor] = None, decoder_prompt_len: Optional[int] = 0, ) -> Tuple[torch.LongTensor]: batch_size = len(self._beam_hyps) // self.num_beam_groups - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] + if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) # finalize all open beam hypotheses and add to generated hypotheses for batch_group_idx, beam_hyp in enumerate(self._beam_hyps): @@ -513,8 +517,8 @@ def process( next_tokens: torch.LongTensor, next_indices: torch.LongTensor, scores_for_all_vocab: torch.FloatTensor, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, + pad_token_id: Optional[Union[int, torch.Tensor]] = None, + eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None, beam_indices: Optional[torch.LongTensor] = None, decoder_prompt_len: Optional[int] = 0, ) -> Tuple[torch.Tensor]: @@ -578,8 +582,10 @@ def process( next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] + if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: @@ -811,15 +817,17 @@ def finalize( final_beam_tokens: torch.LongTensor, final_beam_indices: torch.LongTensor, max_length: int, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, + pad_token_id: Optional[Union[int, torch.Tensor]] = None, + eos_token_id: Optional[Union[int, List[int], 
torch.Tensor]] = None, beam_indices: Optional[torch.LongTensor] = None, decoder_prompt_len: Optional[int] = 0, ) -> Tuple[torch.LongTensor]: batch_size = len(self._beam_hyps) - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] + if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) # finalize all open beam hypotheses and add to generated hypotheses for batch_idx, beam_hyp in enumerate(self._beam_hyps): diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index be8b8136f46c..0a7b1ea2b811 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -2226,6 +2226,7 @@ def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id = torch.tensor(eos_token_id) + self.eos_token_id = eos_token_id if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") @@ -2237,6 +2238,7 @@ def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: scores_processed = scores + self.eos_token_id = self.eos_token_id.to(scores.device) if self.min_eos_p: probs = torch.nn.functional.softmax(scores.float(), dim=-1) # create scores full of -inf except for the eos_token_id diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 808741db05df..ac9d03d0caf6 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1341,13 +1341,17 @@ def _prepare_special_tokens( """ Prepares the special tokens for generation, overwriting the generation config with their processed versions converted to tensor. + Note that `generation_config` is changed in place and stops being serializable after this method is called. - If called outside `generate`, consider creating a copy of `generation_config` first. + That is no problem is callen within `generate` (`generation_config` is a local copy that doesn't leave the + function). However, if called outside `generate`, consider creating a copy of `generation_config` first. """ # Convert special tokens to tensors (if they exist) def _tensor_or_none(token): - return torch.tensor(token, device=self.device, dtype=torch.long) if token is not None else None + if token is None or isinstance(token, torch.Tensor): + return token + return torch.tensor(token, device=self.device, dtype=torch.long) bos_token_id = _tensor_or_none(generation_config.bos_token_id) eos_token_id = _tensor_or_none(generation_config.eos_token_id) @@ -2652,9 +2656,6 @@ def _beam_search( return_dict_in_generate = generation_config.return_dict_in_generate sequential = generation_config.low_memory - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams @@ -2778,7 +2779,7 @@ def _beam_search( next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. 
- n_eos_tokens = len(eos_token_id) if eos_token_id else 0 + n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 next_token_scores, next_tokens = torch.topk( next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True ) @@ -2922,9 +2923,6 @@ def _beam_sample( output_logits = generation_config.output_logits return_dict_in_generate = generation_config.return_dict_in_generate - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams @@ -3145,9 +3143,6 @@ def _group_beam_search( output_logits = generation_config.output_logits return_dict_in_generate = generation_config.return_dict_in_generate - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - num_beams = beam_scorer.num_beams num_beam_groups = beam_scorer.num_beam_groups num_sub_beams = num_beams // num_beam_groups @@ -3250,7 +3245,7 @@ def _group_beam_search( next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. - n_eos_tokens = len(eos_token_id) if eos_token_id else 0 + n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 next_token_scores, next_tokens = torch.topk( next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True ) @@ -3430,9 +3425,6 @@ def _constrained_beam_search( output_logits = generation_config.output_logits return_dict_in_generate = generation_config.return_dict_in_generate - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - batch_size = len(constrained_beam_scorer._beam_hyps) num_beams = constrained_beam_scorer.num_beams @@ -3522,7 +3514,7 @@ def _constrained_beam_search( next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. 
- n_eos_tokens = len(eos_token_id) if eos_token_id else 0 + n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 next_token_scores, next_tokens = torch.topk( next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True ) From 20a06aadbc9aec4bdbfa73bbd5e7afd6e718463c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 3 May 2024 11:47:12 +0000 Subject: [PATCH 5/9] fix musicgen_melody and rag --- .../modeling_musicgen_melody.py | 21 ++++++++++++++++++- src/transformers/models/rag/modeling_rag.py | 3 +++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 55850e0acf9e..d688ce067a92 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -18,7 +18,7 @@ import math import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -2454,6 +2454,25 @@ def freeze_text_encoder(self): param.requires_grad = False self.text_encoder._requires_grad = False + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForConditionalGeneration._get_decoder_start_token_id + def _get_decoder_start_token_id( + self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None + ) -> int: + decoder_start_token_id = ( + decoder_start_token_id + if decoder_start_token_id is not None + else self.generation_config.decoder_start_token_id + ) + bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id + + if decoder_start_token_id is not None: + return decoder_start_token_id + elif bos_token_id is not None: + return bos_token_id + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." 
+ ) + @torch.no_grad() def generate( self, diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 16c9671f441b..6accd3a28131 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1458,6 +1458,9 @@ def generate( generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + generation_config = self._prepare_special_tokens(generation_config, kwargs_has_attention_mask) + # set default parameters n_docs = n_docs if n_docs is not None else self.config.n_docs From 0dcaa3dc78b5cc9e3a27ba9f8f3cc15f7d4b5176 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 3 May 2024 11:47:51 +0000 Subject: [PATCH 6/9] [test_all] empty commit From 12c53ab14f81d11f57f9d7644f48698f937c6630 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 9 May 2024 10:10:15 +0000 Subject: [PATCH 7/9] PR comments --- src/transformers/generation/logits_process.py | 12 +++--------- src/transformers/generation/utils.py | 18 +++++++++++------- src/transformers/models/rag/modeling_rag.py | 2 +- tests/generation/test_utils.py | 2 +- .../seamless_m4t/test_modeling_seamless_m4t.py | 2 +- .../test_modeling_seamless_m4t_v2.py | 2 +- .../test_modeling_speech_to_text.py | 2 +- tests/models/whisper/test_modeling_whisper.py | 2 +- 8 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 0a7b1ea2b811..51df91d99735 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -146,9 +146,6 @@ def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Te eos_token_id = [eos_token_id] eos_token_id = torch.tensor(eos_token_id) - if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): - logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") - self.min_length = min_length self.eos_token_id = eos_token_id @@ -213,9 +210,6 @@ def __init__( eos_token_id = [eos_token_id] eos_token_id = torch.tensor(eos_token_id) - if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): - logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") - self.prompt_length_to_skip = prompt_length_to_skip self.min_new_tokens = min_new_tokens self.eos_token_id = eos_token_id @@ -1492,7 +1486,7 @@ def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Te self.eos_token_id = eos_token_id if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): - logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") + raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: @@ -1610,7 +1604,7 @@ def __init__( self.eos_token_id = eos_token_id if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): - logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") + raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def 
__call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: @@ -2229,7 +2223,7 @@ def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: self.eos_token_id = eos_token_id if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): - logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") + raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") if min_eos_p is not None and min_eos_p <= 0: raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}") diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ac9d03d0caf6..61ea84171a9e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1337,7 +1337,7 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCa def _prepare_special_tokens( self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool] = None - ) -> Tuple[Optional[torch.Tensor]]: + ): """ Prepares the special tokens for generation, overwriting the generation config with their processed versions converted to tensor. @@ -1373,20 +1373,23 @@ def _tensor_or_none(token): pad_token_id = eos_token_id[0] logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.") - # Sanity checks + # Sanity checks/warnings if self.config.is_encoder_decoder and decoder_start_token_id is None: raise ValueError( "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) + if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): + logger.warning( + f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not " + "stop until the maximum length is reached. Depending on other flags, it may even crash." + ) - # Update generation config with the updated special tokens + # Update generation config with the updated special tokens tensors generation_config.bos_token_id = bos_token_id generation_config.eos_token_id = eos_token_id generation_config.pad_token_id = pad_token_id generation_config.decoder_start_token_id = decoder_start_token_id - return generation_config - @torch.no_grad() def generate( self, @@ -1504,7 +1507,7 @@ def generate( accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None - generation_config = self._prepare_special_tokens(generation_config, kwargs_has_attention_mask) + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask) # 3. Define model inputs inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( @@ -1512,12 +1515,13 @@ def generate( ) batch_size = inputs_tensor.shape[0] - # decoder-only models must use left-padding for generation. + # decoder-only models must use left-padding for batched generation. if not self.config.is_encoder_decoder and not is_torchdynamo_compiling(): # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. 
if ( generation_config.pad_token_id is not None + and batch_size > 1 and len(inputs_tensor.shape) == 2 and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 ): diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 6accd3a28131..7eac28ca77e9 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1459,7 +1459,7 @@ def generate( model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None - generation_config = self._prepare_special_tokens(generation_config, kwargs_has_attention_mask) + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask) # set default parameters n_docs = n_docs if n_docs is not None else self.config.n_docs diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index fe5e94310e04..353939833431 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -170,7 +170,7 @@ def _get_encoder_outputs( num_interleave, dim=0 ) generation_config = copy.deepcopy(model.generation_config) - generation_config = model._prepare_special_tokens(generation_config) + model._prepare_special_tokens(generation_config) input_ids = torch.zeros_like(input_ids[:, :1]) + generation_config.decoder_start_token_id attention_mask = None return encoder_outputs, input_ids, attention_mask diff --git a/tests/models/seamless_m4t/test_modeling_seamless_m4t.py b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py index b286f2c61cd3..925d7342931b 100644 --- a/tests/models/seamless_m4t/test_modeling_seamless_m4t.py +++ b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py @@ -415,7 +415,7 @@ def _get_encoder_outputs( num_interleave, dim=0 ) generation_config = copy.deepcopy(model.generation_config) - generation_config = model._prepare_special_tokens(generation_config) + model._prepare_special_tokens(generation_config) input_ids = ( torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device) + generation_config.decoder_start_token_id diff --git a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py index 47d8ff91164c..b36e29d79260 100644 --- a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py +++ b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py @@ -431,7 +431,7 @@ def _get_encoder_outputs( num_interleave, dim=0 ) generation_config = copy.deepcopy(model.generation_config) - generation_config = model._prepare_special_tokens(generation_config) + model._prepare_special_tokens(generation_config) input_ids = ( torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device) + generation_config.decoder_start_token_id diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py index b5dbed1d3de8..5d0e8f3a07af 100644 --- a/tests/models/speech_to_text/test_modeling_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py @@ -646,7 +646,7 @@ def _get_encoder_outputs( ) input_ids = input_ids[:, :, 0] generation_config = copy.deepcopy(model.generation_config) - generation_config = model._prepare_special_tokens(generation_config) + model._prepare_special_tokens(generation_config) input_ids = torch.zeros_like(input_ids[:, :1]) + 
generation_config.decoder_start_token_id attention_mask = None return encoder_outputs, input_ids, attention_mask diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 0c45b18f0547..32b13bd5425f 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -834,7 +834,7 @@ def _get_encoder_outputs( num_interleave, dim=0 ) generation_config = copy.deepcopy(model.generation_config) - generation_config = model._prepare_special_tokens(generation_config) + model._prepare_special_tokens(generation_config) input_ids = input_ids[:, :, 0] input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + generation_config.decoder_start_token_id attention_mask = None From 3069c7ffcabb7947f3553968317c2e61651f29d9 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 9 May 2024 11:10:37 +0100 Subject: [PATCH 8/9] Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 61ea84171a9e..acacc8cd551a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1343,7 +1343,7 @@ def _prepare_special_tokens( converted to tensor. Note that `generation_config` is changed in place and stops being serializable after this method is called. - That is no problem is callen within `generate` (`generation_config` is a local copy that doesn't leave the + That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the function). However, if called outside `generate`, consider creating a copy of `generation_config` first. """ From 19644da4d5ff412aa059e889bb2e92555c7b6675 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 9 May 2024 10:22:02 +0000 Subject: [PATCH 9/9] derp --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index acacc8cd551a..a37cde9df37c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1378,7 +1378,7 @@ def _tensor_or_none(token): raise ValueError( "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) - if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): + if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()): logger.warning( f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not " "stop until the maximum length is reached. Depending on other flags, it may even crash."
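
Note: below is a minimal usage sketch of the in-place special-token preparation introduced by this series, mirroring the pattern used in the updated tests (`copy.deepcopy` of the generation config followed by `model._prepare_special_tokens(...)`). The checkpoint name is only an illustrative choice; `_prepare_special_tokens` is an internal helper that `generate()` normally calls for you.

    import copy
    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # `_prepare_special_tokens` mutates the config in place (and makes it
    # non-serializable), so work on a copy when calling it outside `generate`,
    # as the method's docstring recommends.
    generation_config = copy.deepcopy(model.generation_config)
    model._prepare_special_tokens(generation_config)

    # After the call, the special tokens are torch tensors on the model's device;
    # `eos_token_id` is always kept as a 1D tensor so multiple EOS tokens are supported.
    assert isinstance(generation_config.eos_token_id, torch.Tensor)
    assert generation_config.eos_token_id.ndim == 1
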