Token based (or sequence of token based) repetition penalty exclusion #26902

Closed
@teknium1

Description

Feature request

I started this issue in TGI, but it applies to any inference code that implements some form of repetition penalty, so I am pasting my feature request notes here as well. The original is at huggingface/text-generation-inference#1170

Hello, I would like to propose a feature that lets you specify a list of tokens, or even token sequences, to be excluded from repetition penalty calculations.

The reasoning is that, given a multi-turn prompt format such as:

user: abc
assistant: def
user: ghi
assistant: jkl

Or even worse, a format like ChatML, where it is now standard to use <|im_end|> as the stop token and include it in every turn. Since these tokens appear in every turn, it seems only logical that repetition penalty will undermine these prompt formats, especially when turns are only a few tokens long.

While I haven't noticed this with Hermes 2, that may be solely because its responses are long; if the average turn length is very few tokens, the problem may become more prominent.
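
To make the concern concrete, here is a minimal pure-Python sketch of the usual repetition penalty rule (divide positive logits by the penalty, multiply negative ones) applied to a toy four-token vocabulary. The token ids, logits, and penalty value are invented for illustration and do not come from any real model; `apply_repetition_penalty` is a hypothetical helper, not a library function.

```python
def apply_repetition_penalty(logits, seen_ids, penalty):
    """Return a copy of `logits` with the penalty applied to every seen id."""
    out = list(logits)
    for tok in set(seen_ids):
        # Positive logits shrink, negative logits become more negative.
        out[tok] = out[tok] * penalty if out[tok] < 0 else out[tok] / penalty
    return out


# Toy vocabulary (assumption): id 0 stands in for "<|im_end|>".
IM_END = 0
logits = [5.0, 4.0, 1.0, -1.0]     # the model prefers to end the turn
history = [2, IM_END, 3, IM_END]   # earlier turns already ended with id 0

penalized = apply_repetition_penalty(logits, history, penalty=1.3)

best_before = max(range(len(logits)), key=logits.__getitem__)
best_after = max(range(len(penalized)), key=penalized.__getitem__)
print(best_before, best_after)
```

With these made-up numbers the stop token is the top choice before the penalty and loses to an unseen token afterwards, which is the failure mode described above.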

Motivation

image

Your contribution

The following is TGI's code. I haven't looked at the transformers code, but I assume the principles are the same:

from typing import List

import torch
from transformers import LogitsProcessor


class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
    This version allows for a separate value for each sample and runs inplace when possible.
    It doesn't validate inputs.

    Args:
        repetition_penalty (`List[float]`):
            The parameter for repetition penalty. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    """

    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
        self.penalty = penalty
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        score = torch.where(
            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
        )

        scores.scatter_(1, input_ids, score)
        return scores

    def filter(self, indices):
        self.penalty = [self.penalty[i] for i in indices]
        if any([x != 1.0 for x in self.penalty]):
            self.penalty_tensor = self.penalty_tensor[indices]
            return self
        return None

I'm thinking we take the input_ids before they are used to gather the scores and simply replace them with a version that omits any ids found in a configurable list of token ids (or token strings mapped to ids).
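
A rough sketch of how that could look, under some assumptions: instead of physically removing ids from `input_ids` (which would make rows in a batch ragged), this version computes the penalty as before and then restores the original score wherever the previously seen token is on an exclusion list. The class name `ExcludingRepetitionPenalty` and the `excluded_token_ids` parameter are hypothetical, not part of TGI or transformers, and only single token ids are handled, not token sequences.

```python
from typing import List

import torch


class ExcludingRepetitionPenalty:
    """Hypothetical variant of the processor above that skips listed token ids."""

    def __init__(
        self,
        penalty: List[float],
        excluded_token_ids: List[int],  # assumed new setting, e.g. the id of <|im_end|>
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)
        self.excluded = torch.tensor(
            excluded_token_ids, dtype=torch.long, device=device
        )

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        score = torch.gather(scores, 1, input_ids)

        # Same rule as the original: multiply negative scores, divide positive ones.
        penalized = torch.where(
            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
        )

        # Keep the unpenalized score for any seen token that is excluded.
        is_excluded = torch.isin(input_ids, self.excluded)
        score = torch.where(is_excluded, score, penalized)

        scores.scatter_(1, input_ids, score)
        return scores


# Usage with toy values: token id 3 is excluded from the penalty.
proc = ExcludingRepetitionPenalty([2.0], [3], torch.float32, torch.device("cpu"))
scores = torch.tensor([[4.0, -2.0, 1.0, 6.0]])
input_ids = torch.tensor([[0, 1, 3, 3]])
out = proc(input_ids, scores.clone())
print(out)
```

Masking keeps the tensor shapes fixed, so the batched gather/scatter path stays unchanged; excluding multi-token sequences would need extra matching logic that this sketch does not attempt.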
