Skip to content

Commit 40975a5

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix test sensitivity pyre fix me issues (#1481)
Summary: Fixing unresolved pyre fixme issues in corresponding file Reviewed By: banne01 Differential Revision: D67726420
1 parent 1947050 commit 40975a5

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

tests/helpers/evaluate_linear_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
from typing import cast, Dict
77

88
import torch
9+
10+
from captum._utils.models.linear_model.model import LinearModel
911
from torch import Tensor
12+
from torch.utils.data import DataLoader
1013

1114

12-
# pyre-fixme[2]: Parameter must be annotated.
13-
def evaluate(test_data, classifier) -> Dict[str, Tensor]:
15+
def evaluate(test_data: DataLoader, classifier: LinearModel) -> Dict[str, Tensor]:
1416
classifier.eval()
1517

1618
l1_loss = 0.0

tests/metrics/test_sensitivity.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# pyre-strict
44

55
import typing
6-
from typing import Callable, cast, List, Optional, Tuple, Union
6+
from typing import Callable, List, Optional, Tuple, Union
77

88
import torch
99
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
@@ -28,19 +28,15 @@
2828

2929

3030
@typing.overload
31-
# pyre-fixme[43]: The implementation of `_perturb_func` does not accept all possible
32-
# arguments of overload defined on line `32`.
3331
def _perturb_func(inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: ...
3432

3533

3634
@typing.overload
37-
# pyre-fixme[43]: The implementation of `_perturb_func` does not accept all possible
38-
# arguments of overload defined on line `28`.
3935
def _perturb_func(inputs: Tensor) -> Tensor: ...
4036

4137

4238
def _perturb_func(
43-
inputs: TensorOrTupleOfTensorsGeneric,
39+
inputs: Union[Tensor, Tuple[Tensor, ...]],
4440
) -> Union[Tensor, Tuple[Tensor, ...]]:
4541
def perturb_ratio(input: Tensor) -> Tensor:
4642
return (
@@ -55,7 +51,7 @@ def perturb_ratio(input: Tensor) -> Tensor:
5551
input1 = inputs[0]
5652
input2 = inputs[1]
5753
else:
58-
input1 = cast(Tensor, inputs)
54+
input1 = inputs
5955

6056
perturbed_input1 = input1 + perturb_ratio(input1)
6157

@@ -283,12 +279,13 @@ def test_classification_sensitivity_tpl_target_w_baseline(self) -> None:
283279

284280
def sensitivity_max_assert(
285281
self,
286-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
287-
expl_func: Callable,
282+
expl_func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
288283
inputs: TensorOrTupleOfTensorsGeneric,
289284
expected_sensitivity: Tensor,
290-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
291-
perturb_func: Callable = _perturb_func,
285+
perturb_func: Union[
286+
Callable[[Tensor], Tensor],
287+
Callable[[Tuple[Tensor, ...]], Tuple[Tensor, ...]],
288+
] = _perturb_func,
292289
n_perturb_samples: int = 5,
293290
max_examples_per_batch: Optional[int] = None,
294291
baselines: Optional[BaselineType] = None,

0 commit comments

Comments
 (0)