Skip to content

Commit 3856b51

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix gradcam pyre fixme issues (#1466)
Summary: Pull Request resolved: #1466 Differential Revision: D67705191
1 parent 4aee7ce commit 3856b51

File tree

1 file changed

+3
-19
lines changed

1 file changed

+3
-19
lines changed

captum/attr/_core/layer/grad_cam.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22

33
# pyre-strict
4-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4+
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
55

66
import torch
77
import torch.nn.functional as F
@@ -54,8 +54,7 @@ class LayerGradCam(LayerAttribution, GradientAttribution):
5454

5555
def __init__(
5656
self,
57-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
58-
forward_func: Callable,
57+
forward_func: Callable[..., Tensor],
5958
layer: Module,
6059
device_ids: Union[None, List[int]] = None,
6160
) -> None:
@@ -201,7 +200,7 @@ def attribute(
201200
# hidden layer and hidden layer evaluated at each input.
202201
layer_gradients, layer_evals = compute_layer_gradients_and_eval(
203202
self.forward_func,
204-
self.layer,
203+
cast(Module, self.layer),
205204
inputs,
206205
target,
207206
additional_forward_args,
@@ -213,10 +212,7 @@ def attribute(
213212
summed_grads = tuple(
214213
(
215214
torch.mean(
216-
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
217-
# `Tuple[Tensor, ...]`.
218215
layer_grad,
219-
# pyre-fixme[16]: `tuple` has no attribute `shape`.
220216
dim=tuple(x for x in range(2, len(layer_grad.shape))),
221217
keepdim=True,
222218
)
@@ -228,27 +224,15 @@ def attribute(
228224

229225
if attr_dim_summation:
230226
scaled_acts = tuple(
231-
# pyre-fixme[58]: `*` is not supported for operand types
232-
# `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and
233-
# `Tuple[Tensor, ...]`.
234-
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
235-
# `Tuple[Tensor, ...]`.
236227
torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
237228
for summed_grad, layer_eval in zip(summed_grads, layer_evals)
238229
)
239230
else:
240231
scaled_acts = tuple(
241-
# pyre-fixme[58]: `*` is not supported for operand types
242-
# `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and
243-
# `Tuple[Tensor, ...]`.
244232
summed_grad * layer_eval
245233
for summed_grad, layer_eval in zip(summed_grads, layer_evals)
246234
)
247235

248236
if relu_attributions:
249-
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
250-
# `Union[tuple[Tensor], Tensor]`.
251237
scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
252-
# pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
253-
# `Tuple[Union[tuple[Tensor], Tensor], ...]`.
254238
return _format_output(len(scaled_acts) > 1, scaled_acts)

0 commit comments

Comments
 (0)