54 changes: 32 additions & 22 deletions src/transformers/generation/logits_process.py
@@ -110,6 +110,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.

Examples:

@@ -137,22 +139,21 @@ class MinLengthLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")

if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
eos_token_id = torch.tensor(eos_token_id, device=device)

self.min_length = min_length
self.eos_token_id = eos_token_id

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
scores_processed = scores.clone()
if input_ids.shape[-1] < self.min_length:
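A quick usage sketch of the updated constructor, assuming this branch is installed; the EOS id, vocab size, and `min_length` below are invented for illustration. The point of the change is that the EOS tensor is allocated on `device` at construction instead of being moved inside `__call__`:

```python
import torch
from transformers.generation.logits_process import MinLengthLogitsProcessor

# Toy setup: batch of 1, vocab of 10, EOS id 2, minimum length 5.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = MinLengthLogitsProcessor(min_length=5, eos_token_id=2, device=device)

input_ids = torch.ones((1, 3), dtype=torch.long, device=device)  # only 3 tokens so far
scores = torch.randn(1, 10, device=device)
out = processor(input_ids, scores)
print(out[0, 2])  # -inf: EOS stays blocked until min_length is reached
```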
@@ -173,6 +174,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.

Examples:

@@ -196,7 +199,11 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
"""

def __init__(
self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor]
self,
prompt_length_to_skip: int,
min_new_tokens: int,
eos_token_id: Union[int, List[int], torch.Tensor],
device: str = "cpu",
):
for arg_name, arg_value in [
("prompt_length_to_skip", prompt_length_to_skip),
@@ -208,7 +215,7 @@ def __init__(
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
eos_token_id = torch.tensor(eos_token_id, device=device)

self.prompt_length_to_skip = prompt_length_to_skip
self.min_new_tokens = min_new_tokens
@@ -219,7 +226,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
scores_processed = scores.clone()
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
if new_tokens_length < self.min_new_tokens:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
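Same pattern for the new-tokens variant; a short sketch with made-up values, only to show that the prompt length is skipped before the EOS block applies:

```python
import torch
from transformers.generation.logits_process import MinNewTokensLengthLogitsProcessor

# Hypothetical prompt of length 4; EOS (id 2) is blocked until 3 new tokens exist.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = MinNewTokensLengthLogitsProcessor(
    prompt_length_to_skip=4, min_new_tokens=3, eos_token_id=2, device=device
)

input_ids = torch.ones((1, 5), dtype=torch.long, device=device)  # 4 prompt + 1 new token
scores = torch.randn(1, 10, device=device)
print(processor(input_ids, scores)[0, 2])  # -inf while fewer than 3 new tokens generated
```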
@@ -779,6 +785,8 @@ class EtaLogitsWarper(LogitsWarper):
Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities.
For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation,
even if all tokens have probabilities below the cutoff `eta`.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.

Examples:
```python
@@ -806,7 +814,9 @@ class EtaLogitsWarper(LogitsWarper):
```
"""

def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
def __init__(
self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1, device: str = "cpu"
):
epsilon = float(epsilon)
if epsilon <= 0 or epsilon >= 1:
raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
@@ -817,13 +827,12 @@ def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_toke
f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
)

self.epsilon = torch.tensor(epsilon)
self.epsilon = torch.tensor(epsilon, device=device)
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Calculate the adaptive cutoff
probabilities = scores.softmax(dim=-1)
entropy = torch.distributions.Categorical(logits=scores).entropy()
eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]
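A minimal sketch of the warper with the new `device` kwarg, assuming this branch; logits are random and only illustrative. The epsilon constant now starts on the same device as the scores it filters, so the adaptive cutoff computation needs no transfer:

```python
import torch
from transformers.generation.logits_process import EtaLogitsWarper

device = "cuda" if torch.cuda.is_available() else "cpu"
warper = EtaLogitsWarper(epsilon=0.1, device=device)

input_ids = torch.zeros((1, 4), dtype=torch.long, device=device)
scores = torch.randn(1, 32, device=device)
filtered = warper(input_ids, scores)  # tokens below the adaptive cutoff become -inf
print(torch.isinf(filtered).sum())
```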
@@ -1530,6 +1539,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
The maximum length of the sequence to be generated.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.

Examples:

@@ -1553,13 +1564,13 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
self.max_length = max_length

if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
eos_token_id = torch.tensor(eos_token_id, device=device)
self.eos_token_id = eos_token_id

if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
@@ -1568,7 +1579,6 @@ def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Te
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
self.eos_token_id = self.eos_token_id.to(scores.device)
scores_processed = scores
if cur_len == self.max_length - 1:
scores_processed = torch.full_like(scores, -math.inf)
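For reference, a usage sketch of the updated forced-EOS processor, with invented sizes: once the sequence reaches `max_length - 1`, every token except EOS is masked on the chosen device:

```python
import torch
from transformers.generation.logits_process import ForcedEOSTokenLogitsProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = ForcedEOSTokenLogitsProcessor(max_length=6, eos_token_id=2, device=device)

input_ids = torch.ones((1, 5), dtype=torch.long, device=device)  # length 5 == max_length - 1
scores = torch.randn(1, 10, device=device)
out = processor(input_ids, scores)
print(out.argmax(dim=-1))  # tensor([2]) — EOS is the only token left with a finite score
```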
@@ -1770,8 +1780,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens))
def __init__(self, begin_suppress_tokens, begin_index, device: str = "cpu"):
self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens), device=device)
self.begin_index = begin_index

def set_begin_index(self, begin_index):
@@ -1780,7 +1790,6 @@ def set_begin_index(self, begin_index):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
scores_processed = scores
if input_ids.shape[-1] == self.begin_index:
@@ -1818,13 +1827,12 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, suppress_tokens):
self.suppress_tokens = torch.tensor(list(suppress_tokens))
def __init__(self, suppress_tokens, device: str = "cpu"):
self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device)

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.suppress_tokens = self.suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
scores = torch.where(suppress_token_mask, -float("inf"), scores)
return scores
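And a sketch for the suppress-tokens processor under the same assumptions (toy vocab, arbitrary suppressed ids); the suppressed-id tensor is created on `device` once instead of being moved on every call:

```python
import torch
from transformers.generation.logits_process import SuppressTokensLogitsProcessor

# Hypothetical vocabulary of 10 tokens; ids 1 and 3 are never allowed to be sampled.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = SuppressTokensLogitsProcessor([1, 3], device=device)

input_ids = torch.ones((1, 4), dtype=torch.long, device=device)
scores = torch.randn(1, 10, device=device)
out = processor(input_ids, scores)
print(out[0, [1, 3]])  # both -inf at every generation step
```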
@@ -1915,7 +1923,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
"""

def __init__(
self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None
self,
generate_config,
begin_index: Optional[int] = None,
_detect_timestamp_from_logprob: Optional[bool] = None,
): # support for the kwargs
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
@@ -2292,11 +2303,11 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
Minimum end of speech threshold.
"""

def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float, device: str = "cpu"):
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
eos_token_id = torch.tensor(eos_token_id, device=device)
self.eos_token_id = eos_token_id

if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
@@ -2309,7 +2320,6 @@ def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p:
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores_processed = scores
self.eos_token_id = self.eos_token_id.to(scores.device)
if self.min_eos_p:
probs = torch.nn.functional.softmax(scores.float(), dim=-1)
# create scores full of -inf except for the eos_token_id
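The Bark EOS prioritizer follows the same construction pattern; a rough sketch with invented ids and threshold, showing that when the EOS probability exceeds `min_eos_p` every other token is masked:

```python
import torch
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = BarkEosPrioritizerLogitsProcessor(eos_token_id=2, min_eos_p=0.5, device=device)

scores = torch.full((1, 10), -5.0, device=device)
scores[0, 2] = 5.0  # EOS dominates the softmax, so its probability is above min_eos_p
input_ids = torch.ones((1, 4), dtype=torch.long, device=device)
out = processor(input_ids, scores)
print(torch.isfinite(out[0]).nonzero())  # only index 2 keeps a finite score
```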
74 changes: 60 additions & 14 deletions src/transformers/generation/utils.py
Collaborator comment: should we use `self.device`? or `lm_head.device`? (which is not always there but still)

@@ -723,6 +723,7 @@ def _get_candidate_generator(
def _get_logits_warper(
self,
generation_config: GenerationConfig,
device: str,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@@ -765,7 +766,9 @@ def _get_logits_warper(
)
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
warpers.append(
EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep)
EtaLogitsWarper(
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
)
)
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
@@ -818,7 +821,8 @@ def _get_logits_processor(
):
processors.append(
EncoderRepetitionPenaltyLogitsProcessor(
penalty=generation_config.encoder_repetition_penalty, encoder_input_ids=encoder_input_ids
penalty=generation_config.encoder_repetition_penalty,
encoder_input_ids=encoder_input_ids,
)
)
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
@@ -830,39 +834,63 @@
and generation_config.encoder_no_repeat_ngram_size > 0
):
processors.append(
EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids)
EncoderNoRepeatNGramLogitsProcessor(
generation_config.encoder_no_repeat_ngram_size,
encoder_input_ids,
)
)
if generation_config.bad_words_ids is not None:
processors.append(
NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)
NoBadWordsLogitsProcessor(
generation_config.bad_words_ids,
generation_config.eos_token_id,
)
)
if (
generation_config.min_length is not None
and generation_config.eos_token_id is not None
and generation_config.min_length > 0
):
processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id))
processors.append(
MinLengthLogitsProcessor(
generation_config.min_length,
generation_config.eos_token_id,
device=device,
)
)
if (
generation_config.min_new_tokens is not None
and generation_config.eos_token_id is not None
and generation_config.min_new_tokens > 0
):
processors.append(
MinNewTokensLengthLogitsProcessor(
input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id
input_ids_seq_length,
generation_config.min_new_tokens,
generation_config.eos_token_id,
device=device,
)
)
if prefix_allowed_tokens_fn is not None:
processors.append(
PrefixConstrainedLogitsProcessor(
prefix_allowed_tokens_fn, generation_config.num_beams // generation_config.num_beam_groups
prefix_allowed_tokens_fn,
generation_config.num_beams // generation_config.num_beam_groups,
)
)
if generation_config.forced_bos_token_id is not None:
processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
processors.append(
ForcedBOSTokenLogitsProcessor(
generation_config.forced_bos_token_id,
)
)
if generation_config.forced_eos_token_id is not None:
processors.append(
ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
ForcedEOSTokenLogitsProcessor(
generation_config.max_length,
generation_config.forced_eos_token_id,
device=device,
)
)
if generation_config.remove_invalid_values is True:
processors.append(InfNanRemoveLogitsProcessor())
@@ -875,7 +903,12 @@
)
)
if generation_config.suppress_tokens is not None:
processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens))
processors.append(
SuppressTokensLogitsProcessor(
generation_config.suppress_tokens,
device=device,
)
)
if generation_config.begin_suppress_tokens is not None:
begin_index = input_ids_seq_length
begin_index = (
@@ -887,7 +920,11 @@
# generation starts after the last token that is forced
begin_index += generation_config.forced_decoder_ids[-1][0]
processors.append(
SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
SuppressTokensAtBeginLogitsProcessor(
generation_config.begin_suppress_tokens,
begin_index,
device=device,
)
)
if generation_config.forced_decoder_ids is not None:
# TODO(Sanchit): deprecate in v4.40 by removing this logic
@@ -1779,7 +1816,12 @@ def generate(

# 12. prepare logits warper (if `do_sample` is `True`)
prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None
self._get_logits_warper(
generation_config,
device=input_ids.device,
)
if generation_config.do_sample
else None
)

# 13. run assisted generate
@@ -1812,7 +1854,9 @@
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper
prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)

# 12. expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1838,7 +1882,9 @@
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
# 11. prepare logits warper
prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)

# 12. prepare beam search scorer
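From the caller's side nothing changes; an end-to-end sketch (model name and sampling settings are arbitrary) showing that `generate()` forwards `input_ids.device` to the warper/processor builders, so the constant tensors above are allocated where the logits already live:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

inputs = tok("The device placement", return_tensors="pt").to(device)
# do_sample + eta_cutoff exercises EtaLogitsWarper, built with device=input_ids.device
out = model.generate(**inputs, do_sample=True, eta_cutoff=0.1, max_new_tokens=10)
print(tok.decode(out[0], skip_special_tokens=True))
```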