Commit e4e23f1

Ayush-Warikoo authored and facebook-github-bot committed
Address Pyre FixMe's in layer_feature_permutation.py (#1409)
Summary: This diff addresses a number of the pyre-fixme suppressions in the layer_feature_permutation.py file.

Reviewed By: jjuncho

Differential Revision: D64624875
1 parent: e63d39f

1 file changed: 19 additions (+), 16 deletions (−)

captum/attr/_core/layer/layer_feature_permutation.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

 # pyre-strict
-from typing import Any, Callable, cast, List, Tuple, Type, Union
+from typing import Any, Callable, cast, Dict, List, Tuple, Type, Union

 import torch
 from captum._utils.common import (
@@ -18,7 +18,7 @@
 from captum.attr._core.feature_permutation import FeaturePermutation
 from captum.attr._utils.attribution import LayerAttribution
 from captum.log import log_usage
-from torch import Tensor
+from torch import device, Tensor
 from torch.nn import Module
 from torch.nn.parallel.scatter_gather import scatter

@@ -32,8 +32,7 @@ class LayerFeaturePermutation(LayerAttribution, FeaturePermutation):

     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
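For readers unfamiliar with Pyre's error [24]: a bare Callable is an implicit generic, while Callable[..., Tensor] pins the return type and leaves the parameter list open, which is all the suppression was covering for. A minimal sketch of the pattern, using a hypothetical apply_forward helper that is not part of Captum:

from typing import Any, Callable

import torch
from torch import Tensor


# Hypothetical helper, for illustration only. `Callable[..., Tensor]`
# tells the checker that any call to forward_func returns a Tensor,
# so a pyre-fixme[24] suppression is no longer needed.
def apply_forward(forward_func: Callable[..., Tensor], *args: Any) -> Tensor:
    return forward_func(*args)


print(apply_forward(torch.relu, torch.tensor([-1.0, 2.0])))  # tensor([0., 2.])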
@@ -64,8 +63,7 @@ def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         layer_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
         perturbations_per_eval: int = 1,
     ) -> Union[Tensor, Tuple[Tensor, ...]]:
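On additional_forward_args moving from Any to object: under # pyre-strict, Any in a parameter annotation raises error [2], whereas object still accepts any argument but forces the body to narrow the type before using it. A minimal sketch (the describe_extra_args function is hypothetical, not Captum's API):

# Hypothetical function, for illustration only. `object` accepts any
# argument, but the body must narrow the type (e.g. via isinstance)
# before member access, unlike `Any`, which disables checking entirely.
def describe_extra_args(additional_forward_args: object = None) -> str:
    if isinstance(additional_forward_args, tuple):
        return f"{len(additional_forward_args)} extra forward args"
    return "no extra forward args"


print(describe_extra_args())           # no extra forward args
print(describe_extra_args((1, 2, 3)))  # 3 extra forward args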
@@ -159,28 +157,33 @@ def attribute(
             otherwise a single tensor is returned.
         """

-        # pyre-fixme[2]: Parameter must be annotated.
-        def layer_forward_func(*args) -> Tensor:
-            layer_length = args[-1]
-            layer_input = args[:layer_length]
-            original_inputs = args[layer_length:-1]
+        def layer_forward_func(*args: Any) -> Tensor:
+            r"""
+            Args:
+                args (Any): Tensors comprising the layer input and the original
+                    inputs, and an int representing the length of the layer input
+            """
+            layer_length: int = args[-1]
+            layer_input: Tuple[Tensor, ...] = args[:layer_length]
+            original_inputs: Tuple[Tensor, ...] = args[layer_length:-1]

             device_ids = self.device_ids
             if device_ids is None:
                 device_ids = getattr(self.forward_func, "device_ids", None)

-            all_layer_inputs = {}
+            all_layer_inputs: Dict[device, Tuple[Tensor, ...]] = {}
             if device_ids is not None:
                 scattered_layer_input = scatter(layer_input, target_gpus=device_ids)
                 for device_tensors in scattered_layer_input:
                     all_layer_inputs[device_tensors[0].device] = device_tensors
             else:
                 all_layer_inputs[layer_input[0].device] = layer_input

-            # pyre-fixme[53]: Captured variable `all_layer_inputs` is not annotated.
-            # pyre-fixme[3]: Return type must be annotated.
-            # pyre-fixme[2]: Parameter must be annotated.
-            def forward_hook(module, inp, out=None):
+            def forward_hook(
+                module: Module,
+                inp: Union[None, Tensor, Tuple[Tensor, ...]],
+                out: Union[None, Tensor, Tuple[Tensor, ...]] = None,
+            ) -> Union[Tensor, Tuple[Tensor, ...]]:
                 device = _extract_device(module, inp, out)
                 is_layer_tuple = (
                     isinstance(out, tuple)
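On the new all_layer_inputs annotation (and the added device import): the dictionary maps each torch.device to the tuple of layer-input tensors that live on it. A minimal, CPU-only sketch of that shape, mirroring the single-device branch above; scatter itself requires multiple GPUs, so it is omitted here:

from typing import Dict, Tuple

import torch
from torch import Tensor, device

# One layer input held on the current (CPU) device. On a multi-GPU run,
# scatter() would instead populate one entry per target device.
layer_input: Tuple[Tensor, ...] = (torch.ones(2, 3),)

all_layer_inputs: Dict[device, Tuple[Tensor, ...]] = {}
all_layer_inputs[layer_input[0].device] = layer_input

print(list(all_layer_inputs))  # [device(type='cpu')]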
