
Commit c1b86a3

csauper authored and facebook-github-bot committed

Fix more pyre errors for llm_attr and tests [2/n]

Summary: Get rid of some more pyre errors by adding better typing

Differential Revision: D63365945

1 parent 0af8bd6 · commit c1b86a3

File tree: 3 files changed, +18 -18 lines

captum/attr/_core/llm_attr.py

Lines changed: 5 additions & 9 deletions

@@ -343,7 +343,7 @@ def _forward_func(
             caching=use_cached_outputs,
         )

-        log_prob_list = []
+        log_prob_list: List[Tensor] = []
         outputs = None
         for target_token in target_tokens:
             if use_cached_outputs:
@@ -382,17 +382,15 @@ def _forward_func(
                     (model_inp, torch.tensor([[target_token]]).to(self.device)), dim=1
                 )

-        # pyre-ignore[9] pyre/mypy thinks sum returns int here, but it will return
-        # Tensor
         total_log_prob: Tensor = sum(log_prob_list)  # type: ignore
         # 1st element is the total prob, rest are the target tokens
         # add a leading dim for batch even we only support single instance for now
         if self.include_per_token_attr:
             target_log_probs = torch.stack(
-                [total_log_prob, *log_prob_list], dim=0  # type: ignore
+                [total_log_prob, *log_prob_list], dim=0
             ).unsqueeze(0)
         else:
-            target_log_probs = total_log_prob  # type: ignore
+            target_log_probs = total_log_prob
         target_probs = torch.exp(target_log_probs)

         if _inspect_forward:
@@ -544,8 +542,7 @@ def attribute(
             _convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
         )

-    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
-    def attribute_future(self) -> Callable:
+    def attribute_future(self) -> Callable[[], LLMAttributionResult]:
        r"""
        This method is not implemented for LLMAttribution.
        """
@@ -745,8 +742,7 @@ def attribute(
             _convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
        )

-    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
-    def attribute_future(self) -> Callable:
+    def attribute_future(self) -> Callable[[], LLMAttributionResult]:
        r"""
        This method is not implemented for LLMGradientAttribution.
        """

tests/attr/test_llm_attr.py

Lines changed: 1 addition & 4 deletions

@@ -8,6 +8,7 @@
     cast,
     Dict,
     List,
+    Literal,
     NamedTuple,
     Optional,
     overload,
@@ -18,7 +19,6 @@

 import torch
 from captum._utils.models.linear_model import SkLearnLasso
-from captum._utils.typing import Literal
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
 from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
@@ -44,9 +44,6 @@ class DummyTokenizer:
     @overload
     def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
     @overload
-    # pyre-fixme[43]: Incompatible overload. The implementation of
-    # `DummyTokenizer.encode` does not accept all possible arguments of overload.
-    # pyre-ignore[11]: Annotation `pt` is not defined as a type
     def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

     def encode(
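The DummyTokenizer cleanup works because importing Literal from the standard typing module gives pyre a type it can resolve; with the re-exported captum._utils.typing.Literal it evidently could not (hence the old pyre-ignore[11] note "Annotation `pt` is not defined as a type"). A toy sketch of the same overload pattern, assuming a trivial tokenizer rather than the test's real DummyTokenizer implementation:

from typing import List, Literal, Optional, Union, overload

import torch
from torch import Tensor


class ToyTokenizer:
    @overload
    def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
    @overload
    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

    def encode(
        self, text: str, return_tensors: Optional[str] = None
    ) -> Union[List[int], Tensor]:
        ids = [ord(ch) for ch in text]  # toy "tokenization": one id per character
        if return_tensors == "pt":
            return torch.tensor([ids])
        return ids


tok = ToyTokenizer()
ids: List[int] = tok.encode("hi")        # resolves to the List[int] overload
pt_ids: Tensor = tok.encode("hi", "pt")  # resolves to the Literal["pt"] overload

The identical overload cleanup appears in tests/attr/test_llm_attr_gpu.py below.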

tests/attr/test_llm_attr_gpu.py

Lines changed: 12 additions & 5 deletions

@@ -3,11 +3,21 @@
 # pyre-strict

 import copy
-from typing import Any, cast, Dict, List, NamedTuple, Optional, overload, Type, Union
+from typing import (
+    Any,
+    cast,
+    Dict,
+    List,
+    Literal,
+    NamedTuple,
+    Optional,
+    overload,
+    Type,
+    Union,
+)

 import torch

-from captum._utils.typing import Literal
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
 from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
@@ -32,9 +42,6 @@ class DummyTokenizer:
     @overload
     def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
     @overload
-    # pyre-fixme[43]: Incompatible overload. The implementation of
-    # `DummyTokenizer.encode` does not accept all possible arguments of overload.
-    # pyre-ignore[11]: Annotation `pt` is not defined as a type
     def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

     def encode(
