
Commit cf82bb7

Vivek Miglani authored and facebook-github-bot committed
Fix layer integrated gradients pyre fixme issues (meta-pytorch#1473)
Summary:
Pull Request resolved: meta-pytorch#1473

Differential Revision: D67706224
1 parent 4932420 commit cf82bb7

File tree: 1 file changed (+42, −36)


captum/attr/_core/layer/layer_integrated_gradients.py

Lines changed: 42 additions & 36 deletions
@@ -33,6 +33,7 @@
 )
 from captum.log import log_usage
 from torch import Tensor
+from torch.nn import Module
 from torch.nn.parallel.scatter_gather import scatter
 
 
@@ -58,8 +59,7 @@ class LayerIntegratedGradients(LayerAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: ModuleOrModuleList,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
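
The tightened `forward_func: Callable[..., Tensor]` annotation makes explicit what the class already assumed: the callable passed to the constructor returns a Tensor. A minimal usage sketch, assuming a toy nn.Sequential model (the model, shapes, and target below are illustrative, not part of this commit):

import torch
import torch.nn as nn
from captum.attr import LayerIntegratedGradients

# Toy model: any callable (e.g. a module) whose forward returns a Tensor
# satisfies the Callable[..., Tensor] annotation for forward_func.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
model.eval()

lig = LayerIntegratedGradients(model, layer=model[0], multiply_by_inputs=True)

inputs = torch.randn(4, 8)
baselines = torch.zeros(4, 8)

# Attributions are computed with respect to the outputs of model[0].
attributions = lig.attribute(inputs, baselines=baselines, target=0)
print(attributions.shape)  # torch.Size([4, 16])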
@@ -128,8 +128,7 @@ def _make_gradient_func(
     ) -> Callable[..., Tuple[Tensor, ...]]:
 
         def _gradient_func(
-            # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-            forward_fn: Callable,
+            forward_fn: Callable[..., Tensor],
             inputs: Union[Tensor, Tuple[Tensor, ...]],
             target_ind: TargetType = None,
             additional_forward_args: Optional[object] = None,
@@ -146,28 +145,21 @@ def _gradient_func(
                     target_gpus=self.device_ids,
                 )
 
-            scattered_inputs_dict = {
+            scattered_inputs_dict: Dict[
+                torch.device, Union[Tensor, Tuple[Tensor, ...]]
+            ] = {
                 scattered_input[0].device: scattered_input
                 for scattered_input in scattered_inputs
             }
 
             with torch.autograd.set_grad_enabled(True):
 
-                # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
-                #  annotated.
-                # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
-                #  annotated.
-                # pyre-fixme[3]: Return type must be annotated.
                 def layer_forward_hook(
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    module,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_inputs,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_outputs=None,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    layer_idx=0,
-                ):
+                    module: Module,
+                    hook_inputs: Union[Tensor, Tuple[Tensor, ...]],
+                    hook_outputs: Union[None, Tensor, Tuple[Tensor, ...]] = None,
+                    layer_idx: int = 0,
+                ) -> Union[Tensor, Tuple[Tensor, ...]]:
                     device = _extract_device(module, hook_inputs, hook_outputs)
                     is_layer_tuple = (
                         isinstance(hook_outputs, tuple)
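
For context on the new Dict annotation: scattered_inputs_dict maps each device to the chunk of inputs scattered onto it. A minimal sketch of that mapping for a CPU-only run, where nothing is actually scattered (the input shapes are made up):

from typing import Dict, Tuple, Union

import torch
from torch import Tensor

# With no device_ids, the "scattered" inputs are just the original tuple.
inputs: Tuple[Tensor, ...] = (torch.randn(2, 4), torch.randn(2, 4))
scattered_inputs = (inputs,)

# Same comprehension and annotation as in the diff: key each chunk by the
# device of its first tensor.
scattered_inputs_dict: Dict[torch.device, Union[Tensor, Tuple[Tensor, ...]]] = {
    scattered_input[0].device: scattered_input
    for scattered_input in scattered_inputs
}
print(list(scattered_inputs_dict.keys()))  # [device(type='cpu')]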
@@ -177,11 +169,14 @@ def layer_forward_hook(
                     )
 
                     if is_layer_tuple:
-                        return scattered_inputs_dict[device][
-                            num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
-                                layer_idx + 1
-                            ]
-                        ]
+                        return cast(
+                            Union[Tensor, Tuple[Tensor, ...]],
+                            scattered_inputs_dict[device][
+                                num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
+                                    layer_idx + 1
+                                ]
+                            ],
+                        )
 
                     return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]
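
layer_forward_hook relies on a standard PyTorch hook mechanism: a forward hook that returns a value replaces the module's output, which is how the precomputed layer activations are fed back into the forward pass. A stripped-down sketch of that mechanism, independent of this diff (the Linear layer and replacement tensor are invented for illustration):

import torch
from torch import Tensor
from torch.nn import Module


def make_substituting_hook(replacement: Tensor):
    # Returns a forward hook that ignores the layer's real output and
    # substitutes `replacement` instead, mirroring how layer_forward_hook
    # swaps in slices of scattered_inputs_dict.
    def hook(module: Module, hook_inputs, hook_outputs=None) -> Tensor:
        return replacement

    return hook


layer = torch.nn.Linear(4, 4)
replacement = torch.ones(2, 4)
handle = layer.register_forward_hook(make_substituting_hook(replacement))
out = layer(torch.randn(2, 4))
assert torch.equal(out, replacement)  # the hook's return value wins
handle.remove()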

@@ -255,6 +250,7 @@ def attribute(
     ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ...
 
     @overload
+    @log_usage()
     def attribute(  # type: ignore
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -273,8 +269,7 @@ def attribute(  # type: ignore
     ]: ...
 
     @overload
-    # pyre-fixme[43]: This definition does not have the same decorators as the
-    #  preceding overload(s).
+    @log_usage()
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -296,8 +291,6 @@ def attribute(
     ]: ...
 
     @log_usage()
-    # pyre-fixme[43]: This definition does not have the same decorators as the
-    #  preceding overload(s).
     def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
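
The three attribute hunks all address the same pyre-fixme[43]: Pyre expects every @overload stub to carry the same decorators as the implementation, so the fix is to put @log_usage() on the stubs rather than suppress the warning. A minimal illustration of the pattern with a stand-in decorator (traced and f are invented for this example, not captum code):

from typing import Union, overload


def traced(func):
    # Stand-in for a decorator such as captum's log_usage(); it simply
    # forwards the call so the overload stubs and the implementation match.
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


@overload
@traced
def f(x: int) -> int: ...


@overload
@traced
def f(x: str) -> str: ...


@traced
def f(x: Union[int, str]) -> Union[int, str]:
    return x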
@@ -502,11 +495,22 @@ def attribute(
             additional_forward_args
         )
 
-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def flatten_tuple(tup):
+        def flatten_tuple(tup: List[Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
             return tuple(
-                sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in tup), [])
+                cast(
+                    List[Tensor],
+                    sum(
+                        (
+                            (
+                                list(x)
+                                if isinstance(x, (tuple, list))
+                                else cast(List[Tensor], [x])
+                            )
+                            for x in tup
+                        ),
+                        [],
+                    ),
+                )
             )
 
         if self.device_ids is None:
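
The annotated flatten_tuple collapses a list of per-layer activation tuples into one flat tuple of tensors. A sketch of equivalent loop-based behaviour, shown in isolation (the tensors and layer grouping below are arbitrary, and the loop is a rewrite of the sum(...)-based helper in the diff, not the commit's code):

from typing import List, Tuple

import torch
from torch import Tensor


def flatten_tuple(tup: List[Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
    # Loop-based equivalent of the sum(...)-based helper in the diff.
    flat: List[Tensor] = []
    for x in tup:
        flat.extend(list(x) if isinstance(x, (tuple, list)) else [x])
    return tuple(flat)


a, b, c = torch.zeros(2), torch.ones(3), torch.full((4,), 2.0)
layers = [(a,), (b, c)]  # layer 0 yields one tensor, layer 1 yields two
print(len(flatten_tuple(layers)))  # 3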
@@ -520,16 +524,18 @@ def flatten_tuple(tup):
                 additional_forward_args=additional_forward_args,
                 attribute_to_layer_input=attribute_to_layer_input,
             )
-
+        input_layer_list: List[Tuple[Tensor, ...]]
         # if we have one output
         if not isinstance(self.layer, list):
-            inputs_layer = (inputs_layer,)
+            input_layer_list = [cast(Tuple[Tensor, ...], inputs_layer)]
+        else:
+            input_layer_list = inputs_layer
 
-        num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in inputs_layer]
+        num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in input_layer_list]
         num_outputs_cumsum = torch.cumsum(
             torch.IntTensor([0] + num_outputs), dim=0  # type: ignore
         )
-        inputs_layer = flatten_tuple(inputs_layer)
+        inputs_layer = flatten_tuple(input_layer_list)
 
         baselines_layer = _forward_layer_eval(
             self.forward_func,
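
num_outputs_cumsum records where each layer's (possibly multi-tensor) output starts in the flattened tuple; layer_forward_hook later slices with consecutive offsets. A small worked sketch of that bookkeeping (layer counts and shapes invented for illustration):

import torch
from torch import Tensor

# Pretend we attribute to two layers: the first emits one tensor, the
# second emits a tuple of two tensors.
input_layer_list = [torch.zeros(2, 4), (torch.zeros(2, 8), torch.zeros(2, 8))]

num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in input_layer_list]
num_outputs_cumsum = torch.cumsum(torch.IntTensor([0] + num_outputs), dim=0)

print(num_outputs)         # [1, 2]
print(num_outputs_cumsum)  # tensor([0, 1, 3], dtype=torch.int32)

# Slicing a flattened tuple with consecutive offsets recovers each layer's
# outputs: flat[0:1] belongs to layer 0, flat[1:3] to layer 1.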
