Skip to content

Commit cd0d115

Browse files
craymichael authored and facebook-github-bot committed
Add __call__ to TokenizerLike (#1418)
Summary: Add __call__ to TokenizerLike for transformers compatibility. Differential Revision: D64998805
1 parent cbe45aa commit cd0d115

File tree

4 files changed

+43
-3
lines changed

4 files changed

+43
-3
lines changed

captum/_utils/common.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -90,8 +90,8 @@ def _is_tuple(inputs: Tensor) -> Literal[False]: ...
9090

9191
@typing.overload
9292
def _is_tuple(
93-
inputs: TensorOrTupleOfTensorsGeneric,
94-
) -> bool: ... # type: ignore
93+
inputs: TensorOrTupleOfTensorsGeneric, # type: ignore
94+
) -> bool: ...
9595

9696

9797
def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:

captum/_utils/typing.py

Lines changed: 25 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,18 @@
22

33
# pyre-strict
44

5-
from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union
5+
from collections import UserDict
6+
from typing import (
7+
List,
8+
Literal,
9+
Optional,
10+
overload,
11+
Protocol,
12+
Tuple,
13+
TYPE_CHECKING,
14+
TypeVar,
15+
Union,
16+
)
617

718
from torch import Tensor
819
from torch.nn import Module
@@ -30,6 +41,13 @@
3041
]
3142

3243

44+
# Necessary for Python >=3.7 and <3.9!
45+
if TYPE_CHECKING:
46+
BatchEncodingType = UserDict[Union[int, str], object]
47+
else:
48+
BatchEncodingType = UserDict
49+
50+
3351
class TokenizerLike(Protocol):
3452
"""A protocol for tokenizer-like objects that can be used with Captum
3553
LLM attribution methods."""
@@ -62,3 +80,9 @@ def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: ...
6280
def convert_tokens_to_ids(
6381
self, tokens: Union[List[str], str]
6482
) -> Union[List[int], int]: ...
83+
84+
def __call__(
85+
self,
86+
text: Optional[Union[str, List[str], List[List[str]]]] = None,
87+
return_offsets_mapping: bool = False,
88+
) -> BatchEncodingType: ...

tests/attr/test_interpretable_input.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
from typing import List, Literal, Optional, overload, Union
66

77
import torch
8+
from captum._utils.typing import BatchEncodingType
89
from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
910
from parameterized import parameterized
1011
from tests.helpers import BaseTest
@@ -68,6 +69,13 @@ def convert_tokens_to_ids(
6869
def decode(self, token_ids: Tensor) -> str:
6970
raise NotImplementedError
7071

72+
def __call__(
73+
self,
74+
text: Optional[Union[str, List[str], List[List[str]]]] = None,
75+
return_offsets_mapping: bool = False,
76+
) -> BatchEncodingType:
77+
raise NotImplementedError
78+
7179

7280
class TestTextTemplateInput(BaseTest):
7381
@parameterized.expand(

tests/attr/test_llm_attr.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919

2020
import torch
2121
from captum._utils.models.linear_model import SkLearnLasso
22+
from captum._utils.typing import BatchEncodingType
2223
from captum.attr._core.feature_ablation import FeatureAblation
2324
from captum.attr._core.kernel_shap import KernelShap
2425
from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
@@ -96,6 +97,13 @@ def decode(self, token_ids: Tensor) -> str:
9697
# pyre-fixme[7]: Expected `str` but got `Union[List[str], str]`.
9798
return tokens if isinstance(tokens, str) else " ".join(tokens)
9899

100+
def __call__(
101+
self,
102+
text: Optional[Union[str, List[str], List[List[str]]]] = None,
103+
return_offsets_mapping: bool = False,
104+
) -> BatchEncodingType:
105+
raise NotImplementedError
106+
99107

100108
class Result(NamedTuple):
101109
logits: Tensor

0 commit comments

Comments
 (0)