
Commit 23967e9

Vivek Miglani authored and facebook-github-bot committed
Fix layer integrated gradients pyre fixme issues (#1473)
Summary: Fixing unresolved pyre fixme issues in corresponding file

Differential Revision: D67706224
1 parent a95eb46 commit 23967e9
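The recurring pattern in this commit, sketched below on a standalone annotation (illustrative only, not a line taken from the diff): a bare Callable is what pyre flags with fixme[24]; parameterizing it removes the need for a suppression comment.

# Illustrative sketch of the pattern this commit applies, not code from the diff.
from typing import Callable

from torch import Tensor

# Before: a bare ``Callable`` raises pyre-fixme[24]
#   ("Generic type `Callable` expects 2 type parameters").
# forward_func: Callable

# After: the callable accepts arbitrary arguments and returns a Tensor,
# so the pyre-fixme suppression can be dropped.
forward_func: Callable[..., Tensor]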

File tree

1 file changed: +40 −32 lines changed

captum/attr/_core/layer/layer_integrated_gradients.py

Lines changed: 40 additions & 32 deletions
@@ -33,6 +33,7 @@
 )
 from captum.log import log_usage
 from torch import Tensor
+from torch.nn import Module
 from torch.nn.parallel.scatter_gather import scatter
 
 
@@ -58,8 +59,7 @@ class LayerIntegratedGradients(LayerAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: ModuleOrModuleList,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
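For context, a minimal usage sketch (assuming a toy nn.Sequential model; not part of this diff) showing the kind of forward_func the new Callable[..., Tensor] annotation describes:

# Minimal sketch: construct LayerIntegratedGradients with a model whose
# forward returns a Tensor of scores, matching Callable[..., Tensor].
import torch
import torch.nn as nn
from captum.attr import LayerIntegratedGradients

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
lig = LayerIntegratedGradients(model, layer=model[0])

inputs = torch.randn(4, 8)
baselines = torch.zeros(4, 8)
# Attributions for target class 0 with respect to the first Linear layer.
attributions = lig.attribute(inputs, baselines=baselines, target=0)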
@@ -128,8 +128,7 @@ def _make_gradient_func(
     ) -> Callable[..., Tuple[Tensor, ...]]:
 
         def _gradient_func(
-            # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-            forward_fn: Callable,
+            forward_fn: Callable[..., Tensor],
             inputs: Union[Tensor, Tuple[Tensor, ...]],
             target_ind: TargetType = None,
             additional_forward_args: Optional[object] = None,
@@ -146,28 +145,21 @@ def _gradient_func(
                     target_gpus=self.device_ids,
                 )
 
-            scattered_inputs_dict = {
+            scattered_inputs_dict: Dict[
+                torch.device, Union[Tensor, Tuple[Tensor, ...]]
+            ] = {
                 scattered_input[0].device: scattered_input
                 for scattered_input in scattered_inputs
             }
 
            with torch.autograd.set_grad_enabled(True):
 
-                # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
-                #  annotated.
-                # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
-                #  annotated.
-                # pyre-fixme[3]: Return type must be annotated.
                 def layer_forward_hook(
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    module,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_inputs,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_outputs=None,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    layer_idx=0,
-                ):
+                    module: Module,
+                    hook_inputs: Union[Tensor, Tuple[Tensor, ...]],
+                    hook_outputs: Union[None, Tensor, Tuple[Tensor, ...]] = None,
+                    layer_idx: int = 0,
+                ) -> Union[Tensor, Tuple[Tensor, ...]]:
                     device = _extract_device(module, hook_inputs, hook_outputs)
                     is_layer_tuple = (
                         isinstance(hook_outputs, tuple)
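The annotated hook signature mirrors the callbacks PyTorch passes to register_forward_pre_hook (module and inputs) and register_forward_hook (module, inputs, and outputs), which is why hook_outputs defaults to None. A standalone sketch with a hypothetical demo_hook on a toy Linear layer, not Captum's internal code:

# Sketch: the same parameter shape works both as a forward pre-hook
# (called with module and inputs) and as a forward hook (also given outputs).
from typing import Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module

def demo_hook(
    module: Module,
    hook_inputs: Union[Tensor, Tuple[Tensor, ...]],
    hook_outputs: Union[None, Tensor, Tuple[Tensor, ...]] = None,
) -> None:
    # A forward hook receives outputs; a forward pre-hook does not.
    kind = "pre-hook" if hook_outputs is None else "hook"
    print(kind, module.__class__.__name__)

layer = torch.nn.Linear(4, 2)
pre_handle = layer.register_forward_pre_hook(demo_hook)
handle = layer.register_forward_hook(demo_hook)
layer(torch.randn(1, 4))  # prints once per registered hook
pre_handle.remove()
handle.remove()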
@@ -177,11 +169,14 @@ def layer_forward_hook(
                     )
 
                     if is_layer_tuple:
-                        return scattered_inputs_dict[device][
-                            num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
-                                layer_idx + 1
-                            ]
-                        ]
+                        return cast(
+                            Union[Tensor, Tuple[Tensor, ...]],
+                            scattered_inputs_dict[device][
+                                num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
+                                    layer_idx + 1
+                                ]
+                            ],
+                        )
 
                     return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]
 
@@ -502,11 +497,22 @@ def attribute(
             additional_forward_args
         )
 
-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def flatten_tuple(tup):
+        def flatten_tuple(tup: List[Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
             return tuple(
-                sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in tup), [])
+                cast(
+                    List[Tensor],
+                    sum(
+                        (
+                            (
+                                list(x)
+                                if isinstance(x, (tuple, list))
+                                else cast(List[Tensor], [x])
+                            )
+                            for x in tup
+                        ),
+                        [],
+                    ),
+                )
             )
 
         if self.device_ids is None:
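To illustrate what the newly annotated flatten_tuple computes (the added casts only satisfy the type checker; runtime behavior is unchanged), here is a hypothetical equivalent, flatten_tuple_demo, written without the casts:

# Sketch: a list of per-layer tuples is flattened into one tuple of Tensors.
from typing import List, Tuple

import torch
from torch import Tensor

def flatten_tuple_demo(tup: List[Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
    flat: List[Tensor] = []
    for x in tup:
        flat.extend(list(x) if isinstance(x, (tuple, list)) else [x])
    return tuple(flat)

a, b, c = torch.zeros(1), torch.ones(1), torch.full((1,), 2.0)
print(flatten_tuple_demo([(a, b), (c,)]))  # one flat tuple: (a, b, c)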
@@ -520,16 +526,18 @@ def flatten_tuple(tup):
             additional_forward_args=additional_forward_args,
             attribute_to_layer_input=attribute_to_layer_input,
         )
-
+        input_layer_list: List[Tuple[Tensor, ...]]
         # if we have one output
         if not isinstance(self.layer, list):
-            inputs_layer = (inputs_layer,)
+            input_layer_list = [cast(Tuple[Tensor, ...], inputs_layer)]
+        else:
+            input_layer_list = inputs_layer
 
-        num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in inputs_layer]
+        num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in input_layer_list]
         num_outputs_cumsum = torch.cumsum(
             torch.IntTensor([0] + num_outputs), dim=0  # type: ignore
         )
-        inputs_layer = flatten_tuple(inputs_layer)
+        inputs_layer = flatten_tuple(input_layer_list)
 
         baselines_layer = _forward_layer_eval(
             self.forward_func,
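A small sketch with toy values (not taken from the diff) of how num_outputs_cumsum is built from num_outputs and then used to slice the flattened per-layer tensors; layer i owns the half-open range [cumsum[i], cumsum[i + 1]):

# Sketch: cumulative offsets over per-layer output counts.
import torch

num_outputs = [1, 3, 2]  # e.g. three layers returning 1, 3 and 2 tensors
num_outputs_cumsum = torch.cumsum(torch.IntTensor([0] + num_outputs), dim=0)
print(num_outputs_cumsum.tolist())  # [0, 1, 4, 6]

layer_idx = 1
start, stop = num_outputs_cumsum[layer_idx], num_outputs_cumsum[layer_idx + 1]
print(int(start), int(stop))  # the second layer owns flattened indices 1..3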
