133 changes: 76 additions & 57 deletions src/transformers/generation/logits_process.py
@@ -108,8 +108,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.

Examples:

@@ -137,23 +137,26 @@ class MinLengthLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

self.min_length = min_length
self.eos_token_id = eos_token_id

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
scores_processed = scores.clone()
if input_ids.shape[-1] < self.min_length:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
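[Editor's note, not part of the diff] The pattern this hunk introduces: `eos_token_id` is normalized to a `torch.Tensor` once in `__init__`, and `__call__` only moves the cached tensor to the scores' device before building the `torch.isin` mask. A minimal, self-contained sketch with dummy logits and a toy vocabulary:

```python
import math
import torch

# Normalize `eos_token_id` up front, mirroring the new __init__ logic.
eos_token_id = 2  # could equally be [2, 50256] or an existing tensor
if not isinstance(eos_token_id, torch.Tensor):
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id = torch.tensor(eos_token_id)

scores = torch.randn(1, 10)  # (batch, vocab_size) dummy logits
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, eos_token_id.to(scores.device))

min_length, cur_len = 5, 3
if cur_len < min_length:
    # Below the minimum length, every EOS column is pushed to -inf.
    scores = torch.where(eos_token_mask, -math.inf, scores)
print(scores[0, 2])  # tensor(-inf)
```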
@@ -171,8 +174,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
input length.
min_new_tokens (`int`):
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.

Examples:

@@ -195,18 +198,23 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
def __init__(
self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor]
):
for arg_name, arg_value in [
("prompt_length_to_skip", prompt_length_to_skip),
("min_new_tokens", min_new_tokens),
]:
if not isinstance(arg_value, int) or arg_value < 0:
raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

self.prompt_length_to_skip = prompt_length_to_skip
self.min_new_tokens = min_new_tokens
@@ -217,8 +225,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
scores_processed = scores.clone()
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
if new_tokens_length < self.min_new_tokens:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)

@@ -1112,8 +1120,8 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
Args:
bad_words_ids (`List[List[int]]`):
List of list of token ids that are not allowed to be generated.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`, *optional*):
The id(s) of the *end-of-sequence* token.

Examples:

@@ -1150,18 +1158,22 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
```
"""

def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
def __init__(
self, bad_words_ids: List[List[int]], eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None
):
self.bad_word_ids = bad_words_ids
self._validate_arguments()

# Filter EOS token from bad_words_ids
if eos_token_id is None:
eos_token_id = []
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
bad_words_ids = list(
filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
)
if eos_token_id is not None:
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

bad_words_ids = list(
filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
)

# Forbidding a sequence is equivalent to setting its bias to -inf
sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
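[Editor's note, not part of the diff] A rough, standalone illustration with assumed values of the filtering above: single-token bad-word entries that match an EOS id are dropped, and every remaining sequence is then given a `-inf` additive bias:

```python
import torch

bad_words_ids = [[5], [64, 7], [2]]  # assume token 2 is an EOS id
eos_token_id = torch.tensor([2])

# Same filter as above: keep a sequence only if it matches no single EOS id.
bad_words_ids = list(
    filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
)

# Forbidding a sequence is equivalent to giving it a -inf bias.
sequence_bias = {tuple(seq): float("-inf") for seq in bad_words_ids}
print(sequence_bias)  # {(5,): -inf, (64, 7): -inf} -- the bare EOS entry [2] is gone
```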
@@ -1439,9 +1451,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
Args:
max_length (`int`):
The maximum length of the sequence to be generated.
eos_token_id (`Union[int, List[int]]`):
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token to be forced (whitelisted for the current iteration).

Examples:

@@ -1465,15 +1476,18 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
self.max_length = max_length
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
self.eos_token_id = self.eos_token_id.to(scores.device)
scores_processed = scores
if cur_len == self.max_length - 1:
scores_processed = torch.full_like(scores, -math.inf)
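[Editor's note, not part of the diff] The tail of this method is collapsed in this view; the forcing pattern is roughly the following sketch (assumed, consistent with the visible lines): at the last position every logit is set to `-inf` and the EOS columns are then raised, so EOS is the only possible choice:

```python
import math
import torch

scores = torch.randn(2, 8)              # (batch, vocab) dummy logits
eos_token_id = torch.tensor([2, 7])
cur_len, max_length = 9, 10

scores_processed = scores
if cur_len == max_length - 1:
    # Everything becomes -inf except the EOS ids, which become the argmax.
    scores_processed = torch.full_like(scores, -math.inf)
    scores_processed[:, eos_token_id] = 0
print(scores_processed.argmax(dim=-1))  # every row now picks an EOS id
```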
@@ -1505,15 +1519,15 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
class ExponentialDecayLengthPenalty(LogitsProcessor):
r"""
[`LogitsProcessor`] that exponentially increases the score of the `eos_token_id` after `start_index` has been
reached. This allows generating shorter sequences without having a hard cutoff, allowing the `eos_token` to be
reached. This allows generating shorter sequences without having a hard cutoff, allowing the `eos_token_id` to be
predicted in a meaningful position.

Args:
exponential_decay_length_penalty (`tuple(int, float)`):
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
starts and `decay_factor` represents the factor of exponential decay
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
input_ids_seq_length (`int`):
The length of the input sequence.

@@ -1573,27 +1587,29 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
def __init__(
self,
exponential_decay_length_penalty: Tuple[int, float],
eos_token_id: Union[int, List[int]],
eos_token_id: Union[int, List[int], torch.Tensor],
input_ids_seq_length: int,
):
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
self.regulation_factor = exponential_decay_length_penalty[1]
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
self.eos_token_id = self.eos_token_id.to(scores.device)
penalties = torch.zeros_like(scores)
scores_processed = scores
if cur_len > self.regulation_start:
for i in self.eos_token_id:
penalty_idx = cur_len - self.regulation_start
# To support negative logits we compute the penalty of the absolute value and add to the original logit
penalty = torch.abs(scores[:, i]) * (pow(self.regulation_factor, penalty_idx) - 1)
penalties[:, i] = penalty
scores_processed = scores + penalties
penalty_idx = cur_len - self.regulation_start
# To support negative logits we compute the penalty of the absolute value and add to the original logit
penalty = torch.abs(scores[:, self.eos_token_id]) * (pow(self.regulation_factor, penalty_idx) - 1)
penalties[:, self.eos_token_id] = penalty
scores_processed = scores + penalties
return scores_processed
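[Editor's note, not part of the diff] The refactor above replaces the per-id Python loop with one advanced-indexing step over all EOS columns. A small numeric sketch with made-up values of that vectorized computation:

```python
import torch

scores = torch.tensor([[1.5, -2.0, 0.3, -0.7]])  # (batch=1, vocab=4) toy logits
eos_token_id = torch.tensor([1, 3])
regulation_factor, penalty_idx = 1.5, 2           # assumed decay factor, steps past the start index

penalties = torch.zeros_like(scores)
# The absolute value keeps the boost positive even for negative logits.
penalty = torch.abs(scores[:, eos_token_id]) * (pow(regulation_factor, penalty_idx) - 1)
penalties[:, eos_token_id] = penalty
print(scores + penalties)  # tensor([[1.5000, 0.5000, 0.3000, 0.1750]]) -- only EOS columns change
```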


@@ -1670,7 +1686,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
"""

def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens))
self.begin_index = begin_index

def set_begin_index(self, begin_index):
@@ -1679,8 +1695,8 @@ def set_begin_index(self, begin_index):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
begin_suppress_tokens = torch.tensor(self.begin_suppress_tokens, device=scores.device)
suppress_token_mask = torch.isin(vocab_tensor, begin_suppress_tokens)
self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
scores_processed = scores
if input_ids.shape[-1] == self.begin_index:
scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
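[Editor's note, not part of the diff] Illustrative sketch of the change above: the suppress list is built as a tensor once at construction time and only moved to the scores' device inside `__call__`, so every step reuses the same cached tensor (token ids below are made up):

```python
import torch

begin_suppress_tokens = torch.tensor([220, 50256])  # built once, on CPU; ids are illustrative
begin_index = 4

scores = torch.randn(2, 51200)                       # pretend decoder logits
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
begin_suppress_tokens = begin_suppress_tokens.to(scores.device)  # lazy device move, as in __call__
suppress_token_mask = torch.isin(vocab_tensor, begin_suppress_tokens)

input_len = 4
if input_len == begin_index:
    scores = torch.where(suppress_token_mask, -float("inf"), scores)
print(scores[0, 220].item(), scores[0, 50256].item())  # both -inf at the begin step
```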
@@ -1718,13 +1734,13 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
"""

def __init__(self, suppress_tokens):
self.suppress_tokens = list(suppress_tokens)
self.suppress_tokens = torch.tensor(list(suppress_tokens))

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
suppress_tokens = torch.tensor(self.suppress_tokens, device=scores.device)
suppress_token_mask = torch.isin(vocab_tensor, suppress_tokens)
self.suppress_tokens = self.suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
scores = torch.where(suppress_token_mask, -float("inf"), scores)
return scores

@@ -2212,16 +2228,19 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
</Tip>

Args:
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
min_eos_p (`float`, *optional*):
Minimum end of speech threshold.
"""

def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id

if min_eos_p is not None and min_eos_p <= 0:
raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
self.min_eos_p = min_eos_p
17 changes: 10 additions & 7 deletions src/transformers/generation/stopping_criteria.py
@@ -135,18 +135,21 @@ class EosTokenCriteria(StoppingCriteria):
By default, it uses the `model.generation_config.eos_token_id`.

Args:
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
"""

def __init__(self, eos_token_id: Union[int, List[int]]):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = torch.tensor(eos_token_id)
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor]):
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id

@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
self.eos_token_id = self.eos_token_id.to(input_ids.device)
is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
return is_done
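[Editor's note, not part of the diff] A short sketch with toy ids of the stopping check: the last token of every sequence in the batch is compared against the cached EOS tensor in a single `torch.isin` call:

```python
import torch

eos_token_id = torch.tensor([2, 32000])           # illustrative EOS ids
input_ids = torch.tensor([[5, 9, 2],
                          [5, 9, 7]])

eos_token_id = eos_token_id.to(input_ids.device)  # one-time move, then reused across steps
is_done = torch.isin(input_ids[:, -1], eos_token_id)
print(is_done)                                    # tensor([ True, False])
```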

