Skip to content
2 changes: 1 addition & 1 deletion .github/workflows/retry.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
echo "event: ${{ github.event.workflow_run.conclusion }}"
echo "event: ${{ github.event.workflow_run.event }}"
- name: Rerun Failed Workflows
if: github.event.workflow_run.conclusion == 'failure' && github.event.run_attempt <= 3
if: github.event.workflow_run.conclusion == 'failure' && github.event.workflow_run.run_attempt <= 3
env:
GH_TOKEN: ${{ github.token }}
RUN_ID: ${{ github.event.workflow_run.id }}
Expand Down
4 changes: 2 additions & 2 deletions captum/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def _is_tuple(inputs: Tensor) -> Literal[False]: ...

@typing.overload
def _is_tuple(
inputs: TensorOrTupleOfTensorsGeneric,
) -> bool: ... # type: ignore
inputs: TensorOrTupleOfTensorsGeneric, # type: ignore
) -> bool: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
Expand Down
47 changes: 42 additions & 5 deletions captum/_utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@

# pyre-strict

from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union
from collections import UserDict
from typing import (
List,
Literal,
Optional,
overload,
Protocol,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)

from torch import Tensor
from torch.nn import Module
Expand All @@ -14,7 +25,8 @@
TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool)
ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
BaselineType = Union[None, Tensor, int, float, Tuple[Union[Tensor, int, float], ...]]
BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]]
BaselineType = Union[None, Tensor, int, float, BaselineTupleType]

TensorLikeList1D = List[float]
TensorLikeList2D = List[TensorLikeList1D]
Expand All @@ -30,17 +42,35 @@
]


# UserDict is not subscriptable at runtime before Python 3.9, so only
# parameterize it under TYPE_CHECKING (needed while Python 3.7/3.8 are supported).
if TYPE_CHECKING:
BatchEncodingType = UserDict[Union[int, str], object]
else:
BatchEncodingType = UserDict


class TokenizerLike(Protocol):
"""A protocol for tokenizer-like objects that can be used with Captum
LLM attribution methods."""

@overload
def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
def encode(
self, text: str, add_special_tokens: bool = ..., return_tensors: None = ...
) -> List[int]: ...

@overload
def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
def encode(
self,
text: str,
add_special_tokens: bool = ...,
return_tensors: Literal["pt"] = ...,
) -> Tensor: ...

def encode(
self, text: str, return_tensors: Optional[str] = None
self,
text: str,
add_special_tokens: bool = True,
return_tensors: Optional[str] = None,
) -> Union[List[int], Tensor]: ...

def decode(self, token_ids: Tensor) -> str: ...
Expand All @@ -62,3 +92,10 @@ def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: ...
def convert_tokens_to_ids(
self, tokens: Union[List[str], str]
) -> Union[List[int], int]: ...

def __call__(
self,
text: Optional[Union[str, List[str], List[List[str]]]] = None,
add_special_tokens: bool = True,
return_offsets_mapping: bool = False,
) -> BatchEncodingType: ...
4 changes: 2 additions & 2 deletions captum/attr/_core/feature_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# pyre-strict

import math
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Generator, List, Optional, Tuple, Union

import torch
from captum._utils.common import (
Expand Down Expand Up @@ -573,7 +573,7 @@ def _attribute_progress_setup(
formatted_inputs: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
perturbations_per_eval: int,
**kwargs: Dict[str, Any],
**kwargs: Any,
):
feature_counts = self._get_feature_counts(
formatted_inputs, feature_mask, **kwargs
Expand Down
Loading
Loading