diff --git a/.github/workflows/retry.yml b/.github/workflows/retry.yml index 8489186046..b64b5b1f99 100644 --- a/.github/workflows/retry.yml +++ b/.github/workflows/retry.yml @@ -18,7 +18,7 @@ jobs: echo "event: ${{ github.event.workflow_run.conclusion }}" echo "event: ${{ github.event.workflow_run.event }}" - name: Rerun Failed Workflows - if: github.event.workflow_run.conclusion == 'failure' && github.event.run_attempt <= 3 + if: github.event.workflow_run.conclusion == 'failure' && github.event.workflow_run.run_attempt <= 3 env: GH_TOKEN: ${{ github.token }} RUN_ID: ${{ github.event.workflow_run.id }} diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 0a9a427708..2470ae0c16 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -90,8 +90,8 @@ def _is_tuple(inputs: Tensor) -> Literal[False]: ... @typing.overload def _is_tuple( - inputs: TensorOrTupleOfTensorsGeneric, -) -> bool: ... # type: ignore + inputs: TensorOrTupleOfTensorsGeneric, # type: ignore +) -> bool: ... def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool: diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py index 5381350033..10a2385611 100644 --- a/captum/_utils/typing.py +++ b/captum/_utils/typing.py @@ -2,7 +2,18 @@ # pyre-strict -from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union +from collections import UserDict +from typing import ( + List, + Literal, + Optional, + overload, + Protocol, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) from torch import Tensor from torch.nn import Module @@ -14,7 +25,8 @@ TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool) ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module]) TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]] -BaselineType = Union[None, Tensor, int, float, Tuple[Union[Tensor, int, float], ...]] +BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]] +BaselineType = Union[None, Tensor, int, float, BaselineTupleType] TensorLikeList1D = List[float] TensorLikeList2D = List[TensorLikeList1D] @@ -30,17 +42,35 @@ ] +# Necessary for Python >=3.7 and <3.9! +if TYPE_CHECKING: + BatchEncodingType = UserDict[Union[int, str], object] +else: + BatchEncodingType = UserDict + + class TokenizerLike(Protocol): """A protocol for tokenizer-like objects that can be used with Captum LLM attribution methods.""" @overload - def encode(self, text: str, return_tensors: None = None) -> List[int]: ... + def encode( + self, text: str, add_special_tokens: bool = ..., return_tensors: None = ... + ) -> List[int]: ... + @overload - def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ... + def encode( + self, + text: str, + add_special_tokens: bool = ..., + return_tensors: Literal["pt"] = ..., + ) -> Tensor: ... def encode( - self, text: str, return_tensors: Optional[str] = None + self, + text: str, + add_special_tokens: bool = True, + return_tensors: Optional[str] = None, ) -> Union[List[int], Tensor]: ... def decode(self, token_ids: Tensor) -> str: ... @@ -62,3 +92,10 @@ def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: ... def convert_tokens_to_ids( self, tokens: Union[List[str], str] ) -> Union[List[int], int]: ... + + def __call__( + self, + text: Optional[Union[str, List[str], List[List[str]]]] = None, + add_special_tokens: bool = True, + return_offsets_mapping: bool = False, + ) -> BatchEncodingType: ... 
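Note on the typing hunk above: per the source comment, collections.UserDict only accepts subscription (UserDict[...]) at runtime from Python 3.9 onward, so the parameterized BatchEncodingType alias is hidden behind TYPE_CHECKING and the bare UserDict is used at runtime. Since TokenizerLike is a Protocol, conformance is structural; a minimal sketch of how the new __call__/encode members are meant to be consumed (ids_and_offsets is a hypothetical helper, assuming this patch is applied):

    from captum._utils.typing import BatchEncodingType, TokenizerLike

    def ids_and_offsets(tokenizer: TokenizerLike, text: str):
        # __call__ returns a UserDict-style BatchEncodingType, read by key,
        # mirroring HuggingFace's BatchEncoding object
        enc: BatchEncodingType = tokenizer(
            text, add_special_tokens=False, return_offsets_mapping=True
        )
        return enc["input_ids"], enc["offset_mapping"]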
diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index 4eb46650b6..abdb7e53f9 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -3,7 +3,7 @@ # pyre-strict import math -from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Generator, List, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -573,7 +573,7 @@ def _attribute_progress_setup( formatted_inputs: Tuple[Tensor, ...], feature_mask: Tuple[Tensor, ...], perturbations_per_eval: int, - **kwargs: Dict[str, Any], + **kwargs: Any, ): feature_counts = self._get_feature_counts( formatted_inputs, feature_mask, **kwargs diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index 8f66b07487..f0ef2572bb 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -1,9 +1,14 @@ # pyre-strict + +import warnings + +from abc import ABC + from copy import copy from textwrap import shorten -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union import matplotlib.colors as mcolors @@ -216,7 +221,47 @@ def plot_seq_attr( return fig, ax -def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List[str]: +def _clean_up_pretty_token(token: str) -> str: + """Remove newlines and leading/trailing whitespace from token.""" + return token.replace("\n", "\\n").strip() + + +def _encode_with_offsets( + txt: str, + tokenizer: TokenizerLike, + add_special_tokens: bool = True, + **kwargs: Any, +) -> Tuple[List[int], List[Tuple[int, int]]]: + enc = tokenizer( + txt, + return_offsets_mapping=True, + add_special_tokens=add_special_tokens, + **kwargs, + ) + input_ids = cast(List[int], enc["input_ids"]) + offset_mapping = cast(List[Tuple[int, int]], enc["offset_mapping"]) + assert len(input_ids) == len(offset_mapping), ( + f"{len(input_ids)} != {len(offset_mapping)}: {txt} -> " + f"{input_ids}, {offset_mapping}" + ) + # For the case where offsets are not set properly (the end and start are + # equal for all tokens - fall back on the start of the next span in the + # offset mapping) + offset_mapping_corrected = [] + for i, (start, end) in enumerate(offset_mapping): + if start == end: + if (i + 1) < len(offset_mapping): + end = offset_mapping[i + 1][0] + else: + end = len(txt) + offset_mapping_corrected.append((start, end)) + return input_ids, offset_mapping_corrected + + +def _convert_ids_to_pretty_tokens( + ids: Tensor, + tokenizer: TokenizerLike, +) -> List[str]: """ Convert ids to tokens without ugly unicode characters (e.g., ฤ ). 
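Two clarifications for the hunks above. In feature_ablation.py, `**kwargs: Any` is the accurate annotation: `**kwargs: Dict[str, Any]` would claim that every individual keyword value is itself a dict. In _encode_with_offsets, the correction loop handles offset mappings whose spans come back zero-width (start == end); a standalone trace of that loop on a made-up mapping:

    txt = "ab cd"
    offset_mapping = [(0, 0), (3, 3)]  # degenerate: start == end for every token
    corrected = []
    for i, (start, end) in enumerate(offset_mapping):
        if start == end:
            # borrow the start of the next span, or the end of the text for the last token
            end = offset_mapping[i + 1][0] if (i + 1) < len(offset_mapping) else len(txt)
        corrected.append((start, end))
    print(corrected)  # [(0, 3), (3, 5)] -> covers "ab " and "cd"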
See: https://github.com/huggingface/transformers/issues/4786 and @@ -230,10 +275,57 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List > BPE splitting mostly to avoid digesting spaces since the standard BPE algorithm > used spaces in its process """ + txt = tokenizer.decode(ids) + input_ids: Optional[List[int]] = None + # Don't add special tokens (they're either already there, or we don't want them) + input_ids, offset_mapping = _encode_with_offsets( + txt, tokenizer, add_special_tokens=False + ) + + pretty_tokens = [] + end_prev = -1 + idx = 0 + for i, offset in enumerate(offset_mapping): + start, end = offset + if input_ids[i] != ids[idx]: + # When the re-encoded string doesn't match the original encoding we skip + # this token and hope for the best, falling back on a naive method. This + # can happen when a tokenizer might add a token that corresponds to + # a space only when add_special_tokens=False. + warnings.warn( + f"(i={i}, idx={idx}) input_ids[i] {input_ids[i]} != ids[idx] " + f"{ids[idx]} (corresponding to text: {repr(txt[start:end])}). " + "Skipping this token.", + stacklevel=2, + ) + continue + pretty_tokens.append( + _clean_up_pretty_token(txt[start:end]) + + (" [OVERLAP]" if end_prev > start else "") + ) + end_prev = end + idx += 1 + if len(pretty_tokens) != len(ids): + warnings.warn( + f"Pretty tokens length {len(pretty_tokens)} != ids length {len(ids)}! " + "Falling back to naive decoding logic.", + stacklevel=2, + ) + return _convert_ids_to_pretty_tokens_fallback(ids, tokenizer) + return pretty_tokens + + +def _convert_ids_to_pretty_tokens_fallback( + ids: Tensor, tokenizer: TokenizerLike +) -> List[str]: + """ + Fallback function that naively handles logic when multiple ids map to one string. + """ pretty_tokens = [] idx = 0 while idx < len(ids): decoded = tokenizer.decode(ids[idx]) + decoded_pretty = _clean_up_pretty_token(decoded) # Handle case where single token (e.g. unicode) is split into multiple IDs # NOTE: This logic will fail if a tokenizer splits a token into 3+ IDs if decoded.strip() == "๏ฟฝ" and tokenizer.encode(decoded) != [ids[idx]]: @@ -244,21 +336,118 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List ]: # Both tokens are from a split, combine them decoded = tokenizer.decode(ids[idx : idx + 2]) - pretty_tokens.append(decoded + "[1/2]") - pretty_tokens.append(decoded + "[2/2]") + pretty_tokens.append(decoded_pretty) + pretty_tokens.append(decoded_pretty + " [OVERLAP]") else: # Treat tokens as separate - pretty_tokens.append(decoded) - pretty_tokens.append(decoded_next) + pretty_tokens.append(decoded_pretty) + pretty_tokens.append(_clean_up_pretty_token(decoded_next)) idx += 2 else: # Just a normal token idx += 1 - pretty_tokens.append(decoded) + pretty_tokens.append(decoded_pretty) return pretty_tokens -class LLMAttribution(Attribution): +class BaseLLMAttribution(Attribution, ABC): + """Base class for LLM Attribution methods""" + + SUPPORTED_INPUTS: Tuple[Type[InterpretableInput], ...] + SUPPORTED_METHODS: Tuple[Type[Attribution], ...] 
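The " [OVERLAP]" suffix introduced above marks a token whose character span starts before the previous span has ended, which is what happens when several ids decode into the same piece of text (e.g. one emoji split across multiple ids). A small standalone sketch of that bookkeeping with a made-up span list:

    spans = [(0, 3), (4, 8), (4, 8), (9, 10)]  # hypothetical: two ids share the span (4, 8)
    pretty, end_prev = [], -1
    for start, end in spans:
        label = f"tok[{start}:{end}]"
        pretty.append(label + (" [OVERLAP]" if end_prev > start else ""))
        end_prev = end
    print(pretty)  # ['tok[0:3]', 'tok[4:8]', 'tok[4:8] [OVERLAP]', 'tok[9:10]']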
+ + model: nn.Module + tokenizer: TokenizerLike + device: torch.device + + def __init__( + self, + attr_method: Attribution, + tokenizer: TokenizerLike, + ) -> None: + assert isinstance( + attr_method, self.SUPPORTED_METHODS + ), f"{self.__class__.__name__} does not support {type(attr_method)}" + + super().__init__(attr_method.forward_func) + + # alias, we really need a model and don't support wrapper functions + # coz we need call model.forward, model.generate, etc. + self.model: nn.Module = cast(nn.Module, self.forward_func) + + self.tokenizer: TokenizerLike = tokenizer + self.device: torch.device = ( + cast(torch.device, self.model.device) + if hasattr(self.model, "device") + else next(self.model.parameters()).device + ) + + def _get_target_tokens( + self, + inp: InterpretableInput, + target: Union[str, torch.Tensor, None] = None, + skip_tokens: Union[List[int], List[str], None] = None, + gen_args: Optional[Dict[str, Any]] = None, + ) -> Tensor: + assert isinstance( + inp, self.SUPPORTED_INPUTS + ), f"LLMAttribution does not support input type {type(inp)}" + + if target is None: + # generate when None + assert hasattr(self.model, "generate") and callable(self.model.generate), ( + "The model does not have recognizable generate function." + "Target must be given for attribution" + ) + + if not gen_args: + gen_args = DEFAULT_GEN_ARGS + + model_inp = self._format_model_input(inp.to_model_input()) + output_tokens = self.model.generate(model_inp, **gen_args) + target_tokens = output_tokens[0][model_inp.size(1) :] + else: + assert gen_args is None, "gen_args must be None when target is given" + # Encode skip tokens + if skip_tokens: + if isinstance(skip_tokens[0], str): + skip_tokens = cast(List[str], skip_tokens) + skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens) + else: + skip_tokens = [] + skip_tokens = cast(List[int], skip_tokens) + + if isinstance(target, str): + encoded = self.tokenizer.encode(target) + target_tokens = torch.tensor( + [token for token in encoded if token not in skip_tokens] + ) + elif isinstance(target, torch.Tensor): + target_tokens = target[ + ~torch.isin(target, torch.tensor(skip_tokens, device=target.device)) + ] + else: + raise TypeError( + "target must either be str or Tensor, but the type of target is " + "{}".format(type(target)) + ) + return target_tokens + + def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor: + """ + Convert str to tokenized tensor + to make LLMAttribution work with model inputs of both + raw text and text token tensors + """ + # return tensor(1, n_tokens) + if isinstance(model_input, str): + return self.tokenizer.encode(model_input, return_tensors="pt").to( + self.device + ) + return model_input.to(self.device) + + +class LLMAttribution(BaseLLMAttribution): """ Attribution class for large language models. 
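For orientation, the refactor moves shared logic into BaseLLMAttribution without changing how LLMAttribution is used; a typical call still looks like the sketch below (assuming a HuggingFace-style causal LM `model` and its `tokenizer` are already loaded). The wrapper works on a shallow copy of the attribution method (see the next hunk), so the original FeatureAblation instance remains usable on its own.

    from captum.attr import FeatureAblation, LLMAttribution, TextTemplateInput

    fa = FeatureAblation(model)
    llm_attr = LLMAttribution(fa, tokenizer)
    inp = TextTemplateInput("{} lives in {} and is a {}", ["Dave", "Palm Coast", "lawyer"])
    # target=None makes _get_target_tokens call model.generate();
    # skip_tokens can drop e.g. special token ids from an explicitly given target
    res = llm_attr.attribute(inp, target="He enjoys golf and hiking")
    print(res.seq_attr.shape, res.input_tokens)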
It wraps a perturbation-based attribution algorthm to produce commonly interested attribution @@ -304,11 +493,7 @@ class created with the llm model that follows huggingface style Default: "log_prob" """ - assert isinstance( - attr_method, self.SUPPORTED_METHODS - ), f"LLMAttribution does not support {type(attr_method)}" - - super().__init__(attr_method.forward_func) + super().__init__(attr_method, tokenizer) # shallow copy is enough to avoid modifying original instance self.attr_method: PerturbationAttribution = copy(attr_method) @@ -318,17 +503,6 @@ class created with the llm model that follows huggingface style self.attr_method.forward_func = self._forward_func - # alias, we really need a model and don't support wrapper functions - # coz we need call model.forward, model.generate, etc. - self.model: nn.Module = cast(nn.Module, self.forward_func) - - self.tokenizer: TokenizerLike = tokenizer - self.device: torch.device = ( - cast(torch.device, self.model.device) - if hasattr(self.model, "device") - else next(self.model.parameters()).device - ) - assert attr_target in ( "log_prob", "prob", @@ -427,19 +601,6 @@ def _forward_func( return target_probs if self.attr_target != "log_prob" else target_log_probs - def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor: - """ - Convert str to tokenized tensor - to make LLMAttribution work with model inputs of both - raw text and text token tensors - """ - # return tensor(1, n_tokens) - if isinstance(model_input, str): - return self.tokenizer.encode(model_input, return_tensors="pt").to( - self.device - ) - return model_input.to(self.device) - def attribute( self, inp: InterpretableInput, @@ -466,7 +627,7 @@ def attribute( of integers of the token ids. Default: None num_trials (int, optional): number of trials to run. Return is the average - attribibutions over all the trials. + attributions over all the trials. Defaults: 1. gen_args (dict, optional): arguments for generating the target. Only used if target is not given. When None, the default arguments are used, @@ -481,49 +642,12 @@ def attribute( attr (LLMAttributionResult): Attribution result. token_attr will be None if attr method is Lime or KernelShap. """ - - assert isinstance( - inp, self.SUPPORTED_INPUTS - ), f"LLMAttribution does not support input type {type(inp)}" - - if target is None: - # generate when None - assert hasattr(self.model, "generate") and callable(self.model.generate), ( - "The model does not have recognizable generate function." 
- "Target must be given for attribution" - ) - - if not gen_args: - gen_args = DEFAULT_GEN_ARGS - - model_inp = self._format_model_input(inp.to_model_input()) - output_tokens = self.model.generate(model_inp, **gen_args) - target_tokens = output_tokens[0][model_inp.size(1) :] - else: - assert gen_args is None, "gen_args must be None when target is given" - # Encode skip tokens - if skip_tokens: - if isinstance(skip_tokens[0], str): - skip_tokens = cast(List[str], skip_tokens) - skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens) - else: - skip_tokens = [] - skip_tokens = cast(List[int], skip_tokens) - - if isinstance(target, str): - encoded = self.tokenizer.encode(target) - target_tokens = torch.tensor( - [token for token in encoded if token not in skip_tokens] - ) - elif isinstance(target, torch.Tensor): - target_tokens = target[ - ~torch.isin(target, torch.tensor(skip_tokens, device=target.device)) - ] - else: - raise TypeError( - "target must either be str or Tensor, but the type of target is " - "{}".format(type(target)) - ) + target_tokens = self._get_target_tokens( + inp, + target, + skip_tokens=skip_tokens, + gen_args=gen_args, + ) attr = torch.zeros( [ @@ -577,7 +701,7 @@ def attribute_future(self) -> Callable[[], LLMAttributionResult]: ) -class LLMGradientAttribution(Attribution): +class LLMGradientAttribution(BaseLLMAttribution): """ Attribution class for large language models. It wraps a gradient-based attribution algorthm to produce commonly interested attribution @@ -609,37 +733,12 @@ class created with the llm model that follows huggingface style interface convention tokenizer (Tokenizer): tokenizer of the llm model used in the attr_method """ - assert isinstance( - attr_method, self.SUPPORTED_METHODS - ), f"LLMGradientAttribution does not support {type(attr_method)}" - - super().__init__(attr_method.forward_func) - - # alias, we really need a model and don't support wrapper functions - # coz we need call model.forward, model.generate, etc. - self.model: nn.Module = cast(nn.Module, self.forward_func) + super().__init__(attr_method, tokenizer) # shallow copy is enough to avoid modifying original instance self.attr_method: GradientAttribution = copy(attr_method) self.attr_method.forward_func = GradientForwardFunc(self) - self.tokenizer: TokenizerLike = tokenizer - self.device: torch.device = ( - cast(torch.device, self.model.device) - if hasattr(self.model, "device") - else next(self.model.parameters()).device - ) - - def _format_model_input(self, model_input: Union[Tensor, str]) -> Tensor: - """ - Convert str to tokenized tensor - """ - if isinstance(model_input, str): - return self.tokenizer.encode(model_input, return_tensors="pt").to( - self.device - ) - return model_input.to(self.device) - def attribute( self, inp: InterpretableInput, @@ -673,50 +772,12 @@ def attribute( attr (LLMAttributionResult): attribution result """ - - assert isinstance( - inp, self.SUPPORTED_INPUTS - ), f"LLMGradAttribution does not support input type {type(inp)}" - - if target is None: - # generate when None - assert hasattr(self.model, "generate") and callable(self.model.generate), ( - "The model does not have recognizable generate function." 
- "Target must be given for attribution" - ) - - if not gen_args: - gen_args = DEFAULT_GEN_ARGS - - with torch.no_grad(): - model_inp = self._format_model_input(inp.to_model_input()) - output_tokens = self.model.generate(model_inp, **gen_args) - target_tokens = output_tokens[0][model_inp.size(1) :] - else: - assert gen_args is None, "gen_args must be None when target is given" - # Encode skip tokens - if skip_tokens: - if isinstance(skip_tokens[0], str): - skip_tokens = cast(List[str], skip_tokens) - skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens) - else: - skip_tokens = [] - skip_tokens = cast(List[int], skip_tokens) - - if isinstance(target, str): - encoded = self.tokenizer.encode(target) - target_tokens = torch.tensor( - [token for token in encoded if token not in skip_tokens] - ) - elif isinstance(target, torch.Tensor): - target_tokens = target[ - ~torch.isin(target, torch.tensor(skip_tokens, device=target.device)) - ] - else: - raise TypeError( - "target must either be str or Tensor, but the type of target is " - "{}".format(type(target)) - ) + target_tokens = self._get_target_tokens( + inp, + target, + skip_tokens=skip_tokens, + gen_args=gen_args, + ) attr_inp = inp.to_tensor().to(self.device) diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index 46ee2479ba..3d5f0566f2 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -104,6 +104,9 @@ class to create other types of customized input. is only allowed in certain attribution classes like LLMAttribution for now.) """ + n_itp_features: int + values: List[str] + @abstractmethod def to_tensor(self) -> Tensor: """ diff --git a/captum/attr/_utils/visualization.py b/captum/attr/_utils/visualization.py index 06e4651c2f..508fe3a639 100644 --- a/captum/attr/_utils/visualization.py +++ b/captum/attr/_utils/visualization.py @@ -3,7 +3,7 @@ # pyre-strict import warnings from enum import Enum -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union import matplotlib @@ -444,7 +444,7 @@ def visualize_image_attr_multiple( fig_size: Tuple[int, int] = (8, 6), use_pyplot: bool = True, **kwargs: Any, -) -> Tuple[Figure, Axes]: +) -> Tuple[Figure, Union[Axes, List[Axes]]]: r""" Visualizes attribution using multiple visualization methods displayed in a 1 x k grid, where k is the number of desired visualizations. 
@@ -516,15 +516,19 @@ def visualize_image_attr_multiple( plt_fig = plt.figure(figsize=fig_size) else: plt_fig = Figure(figsize=fig_size) - plt_axis = plt_fig.subplots(1, len(methods)) + plt_axis_np = plt_fig.subplots(1, len(methods), squeeze=True) + plt_axis: Union[Axes, List[Axes]] plt_axis_list: List[Axes] = [] # When visualizing one if len(methods) == 1: - plt_axis_list = [plt_axis] # type: ignore + plt_axis = cast(Axes, plt_axis_np) + plt_axis_list = [plt_axis] # Figure.subplots returns Axes or array of Axes else: - plt_axis_list = plt_axis # type: ignore + # https://github.com/numpy/numpy/issues/24738 + plt_axis = cast(List[Axes], cast(npt.NDArray, plt_axis_np).tolist()) + plt_axis_list = plt_axis # Figure.subplots returns Axes or array of Axes for i in range(len(methods)): diff --git a/captum/metrics/_core/infidelity.py b/captum/metrics/_core/infidelity.py index 8963f168fc..83e7705924 100644 --- a/captum/metrics/_core/infidelity.py +++ b/captum/metrics/_core/infidelity.py @@ -2,7 +2,7 @@ # pyre-strict -from typing import Any, Callable, cast, Optional, Tuple, Union +from typing import Callable, cast, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -15,40 +15,59 @@ ExpansionTypes, safe_div, ) -from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric +from captum._utils.typing import ( + BaselineTupleType, + BaselineType, + TargetType, + TensorOrTupleOfTensorsGeneric, +) from captum.log import log_usage from captum.metrics._utils.batching import _divide_and_aggregate_metrics from torch import Tensor -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. -def infidelity_perturb_func_decorator(multipy_by_inputs: bool = True) -> Callable: +def infidelity_perturb_func_decorator( + multiply_by_inputs: bool = True, + # pyre-ignore[34]: The type variable `Variable[TensorOrTupleOfTensorsGeneric + # <: [torch._tensor.Tensor, typing.Tuple[torch._tensor.Tensor, ...]]]` isn't + # present in the function's parameters. +) -> Callable[ + [Callable[..., TensorOrTupleOfTensorsGeneric]], + Callable[ + [TensorOrTupleOfTensorsGeneric, BaselineType], + Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]], + ], +]: r"""An auxiliary, decorator function that helps with computing perturbations given perturbed inputs. It can be useful for cases - when `pertub_func` returns only perturbed inputs and we + when `perturb_func` returns only perturbed inputs and we internally compute the perturbations as (input - perturbed_input) / (input - baseline) if - multipy_by_inputs is set to True and + multiply_by_inputs is set to True and (input - perturbed_input) otherwise. - If users decorate their `pertub_func` with - `@infidelity_perturb_func_decorator` function then their `pertub_func` + If users decorate their `perturb_func` with + `@infidelity_perturb_func_decorator` function then their `perturb_func` needs to only return perturbed inputs. Args: - multipy_by_inputs (bool): Indicates whether model inputs' + multiply_by_inputs (bool): Indicates whether model inputs' multiplier is factored in the computation of attribution scores. """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
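A minimal usage sketch of the renamed decorator (the Gaussian noise is an arbitrary choice, only for illustration): the decorated function returns perturbed inputs only, and the wrapper derives the perturbations itself as (input - perturbed_input) / (input - baseline) or (input - perturbed_input), depending on multiply_by_inputs (see the following hunk). The decorated function can then be passed directly as perturb_func to infidelity.

    import torch
    from captum.metrics import infidelity_perturb_func_decorator

    @infidelity_perturb_func_decorator(multiply_by_inputs=False)
    def my_perturb_func(inputs):
        # return only the perturbed inputs; the decorator computes the perturbations
        return inputs + 0.01 * torch.randn_like(inputs)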
- def sub_infidelity_perturb_func_decorator(pertub_func: Callable) -> Callable: + def sub_infidelity_perturb_func_decorator( + perturb_func: Callable[..., TensorOrTupleOfTensorsGeneric] + ) -> Callable[ + [TensorOrTupleOfTensorsGeneric, BaselineType], + Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]], + ]: r""" Args: - pertub_func(Callable): Input perturbation function that takes inputs + perturb_func(Callable): Input perturbation function that takes inputs and optionally baselines and returns perturbed inputs Returns: @@ -68,23 +87,18 @@ def sub_infidelity_perturb_func_decorator(pertub_func: Callable) -> Callable: """ - # pyre-fixme[3]: Return type must be annotated. def default_perturb_func( inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None - ): + ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: r""" """ - inputs_perturbed = ( - pertub_func(inputs, baselines) + inputs_perturbed: TensorOrTupleOfTensorsGeneric = ( + perturb_func(inputs, baselines) if baselines is not None - else pertub_func(inputs) + else perturb_func(inputs) ) - inputs_perturbed = _format_tensor_into_tuples(inputs_perturbed) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used - # as `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) - # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - baselines = _format_baseline(baselines, inputs) + inputs_perturbed_formatted = _format_tensor_into_tuples(inputs_perturbed) + inputs_formatted = _format_tensor_into_tuples(inputs) + baselines = _format_baseline(baselines, inputs_formatted) if baselines is None: perturbations = tuple( ( @@ -93,12 +107,12 @@ def default_perturb_func( input, default_denom=1.0, ) - if multipy_by_inputs + if multiply_by_inputs else input - input_perturbed ) - # pyre-fixme[6]: For 2nd argument expected - # `Iterable[Variable[_T2]]` but got `None`. - for input, input_perturbed in zip(inputs, inputs_perturbed) + for input, input_perturbed in zip( + inputs_formatted, inputs_perturbed_formatted + ) ) else: perturbations = tuple( @@ -108,18 +122,16 @@ def default_perturb_func( input - baseline, default_denom=1.0, ) - if multipy_by_inputs + if multiply_by_inputs else input - input_perturbed ) for input, input_perturbed, baseline in zip( - inputs, - # pyre-fixme[6]: For 2nd argument expected - # `Iterable[Variable[_T2]]` but got `None`. - inputs_perturbed, + inputs_formatted, + inputs_perturbed_formatted, baselines, ) ) - return perturbations, inputs_perturbed + return perturbations, inputs_perturbed_formatted return default_perturb_func @@ -128,15 +140,14 @@ def default_perturb_func( @log_usage() def infidelity( - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - perturb_func: Callable, + forward_func: Callable[..., Tensor], + perturb_func: Callable[ + ..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric] + ], inputs: TensorOrTupleOfTensorsGeneric, attributions: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. 
- additional_forward_args: Any = None, + additional_forward_args: object = None, target: TargetType = None, n_perturb_samples: int = 10, max_examples_per_batch: Optional[int] = None, @@ -188,25 +199,25 @@ def infidelity( >>> from captum.metrics import infidelity_perturb_func_decorator - >>> @infidelity_perturb_func_decorator() + >>> @infidelity_perturb_func_decorator() >>> def my_perturb_func(inputs): >>> >>> return perturbed_inputs - In case `multipy_by_inputs` is False we compute perturbations by - `input - perturbed_input` difference and in case `multipy_by_inputs` + In case `multiply_by_inputs` is False we compute perturbations by + `input - perturbed_input` difference and in case `multiply_by_inputs` flag is True we compute it by dividing (input - perturbed_input) by (input - baselines). The user needs to only return perturbed inputs in `perturb_func` as described above. `infidelity_perturb_func_decorator` needs to be used with - `multipy_by_inputs` flag set to False in case infidelity + `multiply_by_inputs` flag set to False in case infidelity score is being computed for attribution maps that are local aka that do not factor in inputs in the final attribution score. Such attribution algorithms include Saliency, GradCam, Guided Backprop, or Integrated Gradients and DeepLift attribution scores that are already - computed with `multipy_by_inputs=False` flag. + computed with `multiply_by_inputs=False` flag. If there are more than one inputs passed to infidelity function those will be passed to `perturb_func` as tuples in the same order as they @@ -283,10 +294,10 @@ def infidelity( meaning that the inputs multiplier isn't factored in the attribution scores. This can be done duing the definition of the attribution algorithm - by passing `multipy_by_inputs=False` flag. + by passing `multiply_by_inputs=False` flag. For example in case of Integrated Gradients (IG) we can obtain local attribution scores if we define the constructor of IG as: - ig = IntegratedGradients(multipy_by_inputs=False) + ig = IntegratedGradients(multiply_by_inputs=False) Some attribution algorithms are inherently local. Examples of inherently local attribution methods include: @@ -409,35 +420,35 @@ def infidelity( >>> infid = infidelity(net, perturb_fn, input, attribution) """ # perform argument formattings - inputs = _format_tensor_into_tuples(inputs) # type: ignore + inputs_formatted = _format_tensor_into_tuples(inputs) + baselines_formatted: BaselineTupleType = None if baselines is not None: - baselines = _format_baseline(baselines, cast(Tuple[Tensor, ...], inputs)) + baselines_formatted = _format_baseline(baselines, inputs_formatted) additional_forward_args = _format_additional_forward_args(additional_forward_args) - attributions = _format_tensor_into_tuples(attributions) # type: ignore + attributions_formatted = _format_tensor_into_tuples(attributions) # Make sure that inputs and corresponding attributions have matching sizes. - assert len(inputs) == len(attributions), ( - """The number of tensors in the inputs and - attributions must match. Found number of tensors in the inputs is: {} and in the - attributions: {}""" - ).format(len(inputs), len(attributions)) - for inp, attr in zip(inputs, attributions): + assert len(inputs_formatted) == len(attributions_formatted), ( + "The number of tensors in the inputs and attributions must match. 
" + f"Found number of tensors in the inputs is: {len(inputs_formatted)} and in " + f"the attributions: {len(attributions_formatted)}" + ) + for inp, attr in zip(inputs_formatted, attributions_formatted): assert inp.shape == attr.shape, ( - """Inputs and attributions must have - matching shapes. One of the input tensor's shape is {} and the - attribution tensor's shape is: {}""" - # pyre-fixme[16]: Module `attr` has no attribute `shape`. - ).format(inp.shape, attr.shape) + "Inputs and attributions must have matching shapes. " + f"One of the input tensor's shape is {inp.shape} and the " + f"attribution tensor's shape is: {attr.shape}" + ) - bsz = inputs[0].size(0) + bsz = inputs_formatted[0].size(0) _next_infidelity_tensors = _make_next_infidelity_tensors_func( forward_func, bsz, perturb_func, - inputs, - baselines, - attributions, + inputs_formatted, + baselines_formatted, + attributions_formatted, additional_forward_args, target, normalize, @@ -447,7 +458,7 @@ def infidelity( # if not normalize, directly return aggrgated MSE ((a-b)^2,) # else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2) agg_tensors = _divide_and_aggregate_metrics( - cast(Tuple[Tensor, ...], inputs), + inputs_formatted, n_perturb_samples, _next_infidelity_tensors, agg_func=_sum_infidelity_tensors, @@ -461,11 +472,7 @@ def infidelity( beta = safe_div(beta_num, beta_denorm) infidelity_values = ( - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - beta**2 * agg_tensors[0] - - 2 * beta * agg_tensors[1] - + agg_tensors[2] + beta * beta * agg_tensors[0] - 2 * beta * agg_tensors[1] + agg_tensors[2] ) else: infidelity_values = agg_tensors[0] @@ -477,10 +484,11 @@ def infidelity( def _generate_perturbations( current_n_perturb_samples: int, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - perturb_func: Callable, - inputs: TensorOrTupleOfTensorsGeneric, - baselines: BaselineType, + perturb_func: Callable[ + ..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric] + ], + inputs: Tuple[Tensor, ...], + baselines: BaselineTupleType, ) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]: r""" The perturbations are generated for each example @@ -491,17 +499,16 @@ def _generate_perturbations( repeated instances per example. """ - # pyre-fixme[3]: Return type must be annotated. - def call_perturb_func(): + def call_perturb_func() -> ( + Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric] + ): r""" """ - baselines_pert = None + baselines_pert: BaselineType = None inputs_pert: Union[Tensor, Tuple[Tensor, ...]] if len(inputs_expanded) == 1: inputs_pert = inputs_expanded[0] if baselines_expanded is not None: - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type - # parameter. - baselines_pert = cast(Tuple, baselines_expanded)[0] + baselines_pert = baselines_expanded[0] else: inputs_pert = inputs_expanded baselines_pert = baselines_expanded @@ -526,9 +533,7 @@ def call_perturb_func(): and baseline.shape[0] > 1 else baseline ) - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type - # parameter. 
- for input, baseline in zip(inputs, cast(Tuple, baselines)) + for input, baseline in zip(inputs, baselines) ) return call_perturb_func() @@ -541,33 +546,32 @@ def _validate_inputs_and_perturbations( ) -> None: # asserts the sizes of the perturbations and inputs assert len(perturbations) == len(inputs), ( - """The number of perturbed - inputs and corresponding perturbations must have the same number of - elements. Found number of inputs is: {} and perturbations: - {}""" - ).format(len(perturbations), len(inputs)) + "The number of perturbed " + "inputs and corresponding perturbations must have the same number of " + f"elements. Found number of inputs is: {len(perturbations)} and " + f"perturbations: {len(inputs)}" + ) # asserts the shapes of the perturbations and perturbed inputs for perturb, input_perturbed in zip(perturbations, inputs_perturbed): assert perturb[0].shape == input_perturbed[0].shape, ( - """Perturbed input - and corresponding perturbation must have the same shape and - dimensionality. Found perturbation shape is: {} and the input shape - is: {}""" - ).format(perturb[0].shape, input_perturbed[0].shape) + "Perturbed input " + "and corresponding perturbation must have the same shape and " + f"dimensionality. Found perturbation shape is: {perturb[0].shape} " + f"and the input shape is: {input_perturbed[0].shape}" + ) def _make_next_infidelity_tensors_func( - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Tensor], bsz: int, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - perturb_func: Callable, - inputs: TensorOrTupleOfTensorsGeneric, - baselines: BaselineType, - attributions: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + perturb_func: Callable[ + ..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric] + ], + inputs: Tuple[Tensor, ...], + baselines: BaselineTupleType, + attributions: Tuple[Tensor, ...], + additional_forward_args: object = None, target: TargetType = None, normalize: bool = False, ) -> Callable[[int], Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]]: @@ -579,15 +583,13 @@ def _next_infidelity_tensors( current_n_perturb_samples, perturb_func, inputs, baselines ) - perturbations = _format_tensor_into_tuples(perturbations) - inputs_perturbed = _format_tensor_into_tuples(inputs_perturbed) + perturbations_formatted = _format_tensor_into_tuples(perturbations) + inputs_perturbed_formatted = _format_tensor_into_tuples(inputs_perturbed) _validate_inputs_and_perturbations( - cast(Tuple[Tensor, ...], inputs), - # pyre-fixme[22]: The cast is redundant. - cast(Tuple[Tensor, ...], inputs_perturbed), - # pyre-fixme[22]: The cast is redundant. 
- cast(Tuple[Tensor, ...], perturbations), + inputs, + inputs_perturbed_formatted, + perturbations_formatted, ) targets_expanded = _expand_target( @@ -603,7 +605,7 @@ def _next_infidelity_tensors( inputs_perturbed_fwd = _run_forward( forward_func, - inputs_perturbed, + inputs_perturbed_formatted, targets_expanded, additional_forward_args_expanded, ) @@ -624,7 +626,7 @@ def _next_infidelity_tensors( attributions_times_perturb = tuple( (attribution_expanded * perturbation).view(attribution_expanded.size(0), -1) for attribution_expanded, perturbation in zip( - attributions_expanded, perturbations + attributions_expanded, perturbations_formatted ) ) @@ -654,7 +656,7 @@ def _next_infidelity_tensors( return _next_infidelity_tensors -# pyre-fixme[3]: Return type must be annotated. -# pyre-fixme[2]: Parameter must be annotated. -def _sum_infidelity_tensors(agg_tensors, tensors): +def _sum_infidelity_tensors( + agg_tensors: Tuple[Tensor, ...], tensors: Tuple[Tensor, ...] +) -> Tuple[Tensor, ...]: return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors)) diff --git a/tests/attr/test_interpretable_input.py b/tests/attr/test_interpretable_input.py index 085813b09f..5b6bb89e0c 100644 --- a/tests/attr/test_interpretable_input.py +++ b/tests/attr/test_interpretable_input.py @@ -5,6 +5,7 @@ from typing import List, Literal, Optional, overload, Union import torch +from captum._utils.typing import BatchEncodingType from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput from parameterized import parameterized from tests.helpers import BaseTest @@ -19,12 +20,23 @@ def __init__(self, vocab_list) -> None: self.unk_idx = len(vocab_list) + 1 @overload - def encode(self, text: str, return_tensors: None = None) -> List[int]: ... + def encode( + self, text: str, add_special_tokens: bool = ..., return_tensors: None = ... + ) -> List[int]: ... + @overload - def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ... + def encode( + self, + text: str, + add_special_tokens: bool = ..., + return_tensors: Literal["pt"] = ..., + ) -> Tensor: ... 
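The widened encode overloads keep this test double structurally compatible with the updated TokenizerLike protocol; the Literal["pt"] overload is what lets a type checker narrow the return type from the call site alone, e.g. (with tok being the DummyTokenizer above):

    tens: Tensor = tok.encode("a b c", return_tensors="pt")  # second overload -> Tensor

The first overload correspondingly types return_tensors=None calls as List[int], although this particular stub only implements the "pt" path at runtime.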
def encode( - self, text: str, return_tensors: Optional[str] = "pt" + self, + text: str, + add_special_tokens: bool = True, + return_tensors: Optional[str] = "pt", ) -> Union[List[int], Tensor]: assert return_tensors == "pt" return torch.tensor([self.convert_tokens_to_ids(text.split(" "))]) @@ -68,6 +80,14 @@ def convert_tokens_to_ids( def decode(self, token_ids: Tensor) -> str: raise NotImplementedError + def __call__( + self, + text: Optional[Union[str, List[str], List[List[str]]]] = None, + add_special_tokens: bool = True, + return_offsets_mapping: bool = False, + ) -> BatchEncodingType: + raise NotImplementedError + class TestTextTemplateInput(BaseTest): @parameterized.expand( diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py index d22bef384b..0bbe4f4e73 100644 --- a/tests/attr/test_llm_attr.py +++ b/tests/attr/test_llm_attr.py @@ -3,6 +3,8 @@ # pyre-strict import copy + +from collections import UserDict from typing import ( Any, cast, @@ -19,6 +21,7 @@ import torch from captum._utils.models.linear_model import SkLearnLasso +from captum._utils.typing import BatchEncodingType from captum.attr._core.feature_ablation import FeatureAblation from captum.attr._core.kernel_shap import KernelShap from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap @@ -39,24 +42,38 @@ class DummyTokenizer: vocab_size: int = 256 sos: int = 0 unk: int = 1 - special_tokens: Dict[int, str] = {sos: "", unk: ""} + sos_str: str = "" + special_tokens: Dict[int, str] = {sos: sos_str, unk: ""} @overload - def encode(self, text: str, return_tensors: None = None) -> List[int]: ... + def encode( + self, text: str, add_special_tokens: bool = ..., return_tensors: None = ... + ) -> List[int]: ... + @overload - def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ... + def encode( + self, + text: str, + add_special_tokens: bool = ..., + return_tensors: Literal["pt"] = ..., + ) -> Tensor: ... def encode( - self, text: str, return_tensors: Optional[str] = None + self, + text: str, + add_special_tokens: bool = True, + return_tensors: Optional[str] = None, ) -> Union[List[int], Tensor]: tokens = text.split(" ") tokens_ids: Union[List[int], Tensor] = [ - ord(s[0]) if len(s) == 1 else self.unk for s in tokens + ord(s[0]) if len(s) == 1 else (self.sos if s == self.sos_str else self.unk) + for s in tokens ] # start with sos - tokens_ids = [self.sos, *tokens_ids] + if add_special_tokens: + tokens_ids = [self.sos, *tokens_ids] if return_tensors: return torch.tensor([tokens_ids]) @@ -96,6 +113,30 @@ def decode(self, token_ids: Tensor) -> str: # pyre-fixme[7]: Expected `str` but got `Union[List[str], str]`. 
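A worked example of this dummy encoding scheme (single-character words map to their ord() value, anything else to the unk id, and the sos id is prepended only when add_special_tokens=True), useful when reading the expected values in the tests below:

    tok = DummyTokenizer()
    print(tok.encode("a b cd"))                            # [0, 97, 98, 1]  (sos, 'a', 'b', unk)
    print(tok.encode("a b cd", add_special_tokens=False))  # [97, 98, 1]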
return tokens if isinstance(tokens, str) else " ".join(tokens) + def __call__( + self, + text: Optional[Union[str, List[str], List[List[str]]]] = None, + add_special_tokens: bool = True, + return_offsets_mapping: bool = False, + ) -> BatchEncodingType: + assert isinstance(text, str) + input_ids = self.encode(text, add_special_tokens=add_special_tokens) + + result: BatchEncodingType = UserDict() + result["input_ids"] = input_ids + + if return_offsets_mapping: + offset_mapping = [] + if add_special_tokens: + offset_mapping.append((0, 0)) + idx = 0 + for token in text.split(" "): + offset_mapping.append((idx - (0 if idx == 0 else 1), idx + len(token))) + idx += len(token) + 1 # +1 for space + result["offset_mapping"] = offset_mapping + + return result + class Result(NamedTuple): logits: Tensor diff --git a/tests/attr/test_llm_attr_hf_compatibility.py b/tests/attr/test_llm_attr_hf_compatibility.py index 9798867866..f465cb0ef2 100644 --- a/tests/attr/test_llm_attr_hf_compatibility.py +++ b/tests/attr/test_llm_attr_hf_compatibility.py @@ -1,11 +1,15 @@ #!/usr/bin/env python3 - +import warnings from typing import cast, Dict, Optional, Type import torch from captum.attr._core.feature_ablation import FeatureAblation -from captum.attr._core.llm_attr import LLMAttribution +from captum.attr._core.llm_attr import ( + _convert_ids_to_pretty_tokens, + _convert_ids_to_pretty_tokens_fallback, + LLMAttribution, +) from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling from captum.attr._utils.attribution import PerturbationAttribution from captum.attr._utils.interpretable_input import TextTemplateInput @@ -15,7 +19,7 @@ HAS_HF = True try: - # pyre-fixme[21]: Could not find a module corresponding to import `transformers` + # pyre-ignore[21]: Could not find a module corresponding to import `transformers` from transformers import AutoModelForCausalLM, AutoTokenizer except ImportError: HAS_HF = False @@ -87,3 +91,191 @@ def test_llm_attr_hf_compatibility( self.assertEqual(res.input_tokens, ["a", "c", "d", "f"]) self.assertEqual(res.seq_attr.device.type, self.device) self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device) + + +class TestTokenizerHFCompatibility(BaseTest): + def setUp(self) -> None: + if not HAS_HF: + self.skipTest("transformers package not found, skipping tests") + super().setUp() + + @parameterized.expand([(True,), (False,)]) + def test_tokenizer_pretty_print(self, add_special_tokens: bool) -> None: + tokenizer = AutoTokenizer.from_pretrained( + "hf-internal-testing/tiny-random-LlamaForCausalLM" + ) + txt = ( + 'One two three\n๐Ÿ˜\n๐Ÿ˜‚\n๐Ÿ˜ธ\n๐Ÿ˜\n๐Ÿ˜‚\n๐Ÿ˜ธ\n๐Ÿ˜\n\'๐Ÿ˜‚\n๐Ÿ˜ธ๐Ÿ˜‚\n๐Ÿ˜๐Ÿ˜๐Ÿ˜๐Ÿ˜๐Ÿ˜\n๐Ÿ˜‚:\n"๐Ÿ˜ธ"\n๐Ÿ˜‚' + "\n๏ฟฝ\n\nเฐฅเฎเซนเงฃเค†ฮ”ฮ˜ฯ–\n" + ) + special_tokens_pretty = [ + "", + "One", + ] + no_special_tokens_pretty = [ + "One", + ] + expected_tokens_tail_pretty = [ + "two", + "three", + "\\n", + "๐Ÿ˜", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "\\n", + "๐Ÿ˜‚", + "๐Ÿ˜‚ [OVERLAP]", + "๐Ÿ˜‚ [OVERLAP]", + "๐Ÿ˜‚ [OVERLAP]", + "\\n", + "๐Ÿ˜ธ", + "๐Ÿ˜ธ [OVERLAP]", + "๐Ÿ˜ธ [OVERLAP]", + "๐Ÿ˜ธ [OVERLAP]", + "\\n", + "๐Ÿ˜", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "\\n", + "๐Ÿ˜‚", + "๐Ÿ˜‚ [OVERLAP]", + "๐Ÿ˜‚ [OVERLAP]", + "๐Ÿ˜‚ [OVERLAP]", + "\\n", + "๐Ÿ˜ธ", + "๐Ÿ˜ธ [OVERLAP]", + "๐Ÿ˜ธ [OVERLAP]", + "๐Ÿ˜ธ [OVERLAP]", + "\\n", + "๐Ÿ˜", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "\\n", + "'", + "๐Ÿ˜‚", + "๐Ÿ˜‚ [OVERLAP]", + "๐Ÿ˜‚ [OVERLAP]", + "๐Ÿ˜‚ [OVERLAP]", + "\\n", + "๐Ÿ˜ธ", 
+ "๐Ÿ˜ธ [OVERLAP]", + "๐Ÿ˜ธ [OVERLAP]", + "๐Ÿ˜ธ [OVERLAP]", + "๐Ÿ˜‚", + "๐Ÿ˜‚ [OVERLAP]", + "๐Ÿ˜‚ [OVERLAP]", + "๐Ÿ˜‚ [OVERLAP]", + "\\n", + "๐Ÿ˜", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "๐Ÿ˜ [OVERLAP]", + "\\n", + "๐Ÿ˜‚", + "๐Ÿ˜‚ [OVERLAP]", + "๐Ÿ˜‚ [OVERLAP]", + "๐Ÿ˜‚ [OVERLAP]", + ":", + "\\n", + '"', + "๐Ÿ˜ธ", + "๐Ÿ˜ธ [OVERLAP]", + "๐Ÿ˜ธ [OVERLAP]", + "๐Ÿ˜ธ [OVERLAP]", + '"', + "\\n", + "๐Ÿ˜‚", + "๐Ÿ˜‚ [OVERLAP]", + "๐Ÿ˜‚ [OVERLAP]", + "๐Ÿ˜‚ [OVERLAP]", + "\\n", + "๏ฟฝ", + "\\n", + "\\n", + "เฐฅ", + "เฐฅ [OVERLAP]", + "เฐฅ [OVERLAP]", + "เฎ", + "เฎ [OVERLAP]", + "เฎ [OVERLAP]", + "เซน", + "เซน [OVERLAP]", + "เซน [OVERLAP]", + "เงฃ", + "เงฃ [OVERLAP]", + "เงฃ [OVERLAP]", + "เค†", + "ฮ”", + "ฮ˜", + "ฯ–", + "ฯ– [OVERLAP]", + "\\n", + ] + ids = tokenizer.encode(txt, add_special_tokens=add_special_tokens) + head_pretty = ( + special_tokens_pretty if add_special_tokens else no_special_tokens_pretty + ) + with warnings.catch_warnings(): + if add_special_tokens: + # This particular tokenizer adds a token for the space after when + # we encode the decoded ids in _convert_ids_to_pretty_tokens + warnings.filterwarnings( + "ignore", category=UserWarning, message=".* Skipping this token." + ) + self.assertEqual( + _convert_ids_to_pretty_tokens(ids, tokenizer), + head_pretty + expected_tokens_tail_pretty, + ) + + @parameterized.expand([(True,), (False,)]) + def test_tokenizer_pretty_print_fallback(self, add_special_tokens: bool) -> None: + tokenizer = AutoTokenizer.from_pretrained( + "hf-internal-testing/tiny-random-LlamaForCausalLM" + ) + txt = "Running and jumping and climbing:\nMeow meow meow" + ids = tokenizer.encode(txt, add_special_tokens=add_special_tokens) + + special_tokens_pretty = ["", "Running"] + no_special_tokens_pretty = ["Running"] + expected_tokens_tail_pretty = [ + "and", + "jump", + "ing", + "and", + "clim", + "bing", + ":", + "\\n", + "Me", + "ow", + "me", + "ow", + "me", + "ow", + ] + head_pretty = ( + special_tokens_pretty if add_special_tokens else no_special_tokens_pretty + ) + self.assertEqual( + _convert_ids_to_pretty_tokens_fallback(ids, tokenizer), + head_pretty + expected_tokens_tail_pretty, + )