
Commit bd7e429

craymichael authored and facebook-github-bot committed
Gradient-based LLM attribution in tutorial + LLM attribution type annotations (#1333)
Summary:
Pull Request resolved: #1333

Add gradient-based LLM attribution to the tutorial notebook. Addresses #1237.

Additionally, add more type annotations to llm_attr.py.

Differential Revision: D61461521
1 parent 09aa048 commit bd7e429
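As context for the tutorial change, below is a minimal sketch of gradient-based LLM attribution using Captum's LLMGradientAttribution wrapping LayerIntegratedGradients. The model/tokenizer choice ("gpt2"), the embedding-layer path, the prompt, and the target string are illustrative assumptions; the updated tutorial notebook is the authoritative walkthrough.

# Minimal sketch of gradient-based LLM attribution (model name, embedding
# path, prompt, and target are placeholder assumptions; see the tutorial).
from transformers import AutoModelForCausalLM, AutoTokenizer

from captum.attr import LayerIntegratedGradients, LLMGradientAttribution, TextTokenInput

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Attribute with respect to the input token embeddings (wte for GPT-2).
lig = LayerIntegratedGradients(model, model.transformer.wte)
llm_attr = LLMGradientAttribution(lig, tokenizer)

inp = TextTokenInput("Dave lives in Palm Coast, FL and is a lawyer.", tokenizer)
attr_result = llm_attr.attribute(inp, target="Palm Coast")

print(attr_result.seq_attr)    # per-input-token attribution, aggregated over the target
print(attr_result.token_attr)  # attribution per target token x input token

LLMGradientAttribution adapts a token-level gradient method to text inputs and maps the resulting scores back onto input and target tokens.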

File tree

5 files changed: +217 -143 lines changed


captum/_utils/typing.py

Lines changed: 12 additions & 1 deletion
@@ -2,7 +2,7 @@
 
 # pyre-strict
 
-from typing import List, Tuple, TYPE_CHECKING, TypeVar, Union
+from typing import List, Optional, Protocol, Tuple, TYPE_CHECKING, TypeVar, Union
 
 from torch import Tensor
 from torch.nn import Module
@@ -33,3 +33,14 @@
     TensorLikeList4D,
     TensorLikeList5D,
 ]
+
+
+class TokenizerLike(Protocol):
+    """A protocol for tokenizer-like objects that can be used with Captum
+    LLM attribution methods."""
+
+    def encode(
+        self, text: str, return_tensors: Optional[str] = None
+    ) -> Union[List[int], Tensor]: ...
+    def decode(self, token_ids: Tensor) -> str: ...
+    def convert_ids_to_tokens(self, token_ids: Tensor) -> List[str]: ...
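Because TokenizerLike is a typing.Protocol, it is satisfied structurally: any object exposing encode, decode, and convert_ids_to_tokens with compatible signatures type-checks against it, no inheritance required. The sketch below uses a made-up WhitespaceTokenizer purely to illustrate the shape of the interface; it is not part of this commit.

# Hypothetical tokenizer that structurally satisfies TokenizerLike (illustration only).
from typing import Dict, List, Optional, Union

import torch
from torch import Tensor


class WhitespaceTokenizer:
    """Toy whitespace tokenizer matching the TokenizerLike method signatures."""

    def __init__(self) -> None:
        self.vocab: Dict[str, int] = {}      # token string -> id
        self.inv_vocab: Dict[int, str] = {}  # id -> token string

    def encode(
        self, text: str, return_tensors: Optional[str] = None
    ) -> Union[List[int], Tensor]:
        ids: List[int] = []
        for tok in text.split():
            if tok not in self.vocab:
                idx = len(self.vocab)
                self.vocab[tok] = idx
                self.inv_vocab[idx] = tok
            ids.append(self.vocab[tok])
        if return_tensors == "pt":
            # Follow the Hugging Face convention of a (1, seq_len) tensor.
            return torch.tensor([ids])
        return ids

    def decode(self, token_ids: Tensor) -> str:
        return " ".join(self.inv_vocab[int(i)] for i in token_ids)

    def convert_ids_to_tokens(self, token_ids: Tensor) -> List[str]:
        return [self.inv_vocab[int(i)] for i in token_ids]


tok = WhitespaceTokenizer()
ids = tok.encode("hello captum world")                   # [0, 1, 2]
print(tok.convert_ids_to_tokens(torch.tensor(ids)))      # ['hello', 'captum', 'world']

A real Hugging Face tokenizer already provides these three methods, so it satisfies the protocol as-is; the toy class above only shows the minimum surface Captum relies on.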
