
Commit 6adb36a

Vivek Miglani authored and facebook-github-bot committed
Fix internal influence pyre fixme issues (#1467)
Summary: Pull Request resolved: #1467

Differential Revision: D67705214
1 parent 3856b51 commit 6adb36a

File tree

1 file changed: +4 −8 lines changed


captum/attr/_core/layer/internal_influence.py

Lines changed: 4 additions & 8 deletions
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

 # pyre-strict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

 import torch
 from captum._utils.common import (
@@ -41,8 +41,7 @@ class InternalInfluence(LayerAttribution, GradientAttribution):

     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
@@ -293,7 +292,7 @@ def _attribute(
         # Returns gradient of output with respect to hidden layer.
         layer_gradients, _ = compute_layer_gradients_and_eval(
             forward_fn=self.forward_func,
-            layer=self.layer,
+            layer=cast(Module, self.layer),
             inputs=scaled_features_tpl,
             target_ind=expanded_target,
             additional_forward_args=input_additional_args,
@@ -304,9 +303,7 @@ def _attribute(
         # flattening grads so that we can multiply it with step-size
         # calling contiguous to avoid `memory whole` problems
         scaled_grads = tuple(
-            # pyre-fixme[16]: `tuple` has no attribute `contiguous`.
             layer_grad.contiguous().view(n_steps, -1)
-            # pyre-fixme[16]: `tuple` has no attribute `device`.
             * torch.tensor(step_sizes).view(n_steps, 1).to(layer_grad.device)
             for layer_grad in layer_gradients
         )
@@ -317,8 +314,7 @@ def _attribute(
                 scaled_grad,
                 n_steps,
                 inputs[0].shape[0],
-                # pyre-fixme[16]: `tuple` has no attribute `shape`.
-                layer_grad.shape[1:],
+                tuple(layer_grad.shape[1:]),
             )
             for scaled_grad, layer_grad in zip(scaled_grads, layer_gradients)
         )
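Note (not part of the commit): the change replaces pyre-fixme suppressions with explicit typing, namely a parameterized Callable[..., Tensor] annotation and typing.cast to narrow an attribute whose declared type is broader than what the helper function expects. The minimal, self-contained sketch below illustrates the same pattern; the ToyAttribution class and run_layer helper are hypothetical stand-ins, not Captum APIs.

#!/usr/bin/env python3
# Minimal sketch of the typing pattern used in this commit: annotate a
# forward function as Callable[..., Tensor] and use typing.cast to narrow
# a broadly typed attribute, instead of suppressing the type checker.
# ToyAttribution and run_layer are illustrative, not Captum APIs.

from typing import Callable, List, Optional, cast

import torch
from torch import Tensor
from torch.nn import Module


def run_layer(layer: Module, inputs: Tensor) -> Tensor:
    # Stand-in for a helper whose signature requires a Module,
    # not Optional[Module].
    return layer(inputs)


class ToyAttribution:
    def __init__(
        self,
        # Parameterized Callable: accepts any arguments, returns a Tensor.
        forward_func: Callable[..., Tensor],
        layer: Optional[Module] = None,
        device_ids: Optional[List[int]] = None,
    ) -> None:
        self.forward_func = forward_func
        self.layer = layer
        self.device_ids = device_ids

    def attribute(self, inputs: Tensor) -> Tensor:
        # cast() narrows Optional[Module] to Module for the type checker;
        # it is a no-op at runtime, so a None layer would still fail loudly.
        layer = cast(Module, self.layer)
        return run_layer(layer, inputs)


if __name__ == "__main__":
    linear = torch.nn.Linear(4, 2)
    toy = ToyAttribution(forward_func=linear, layer=linear)
    print(toy.attribute(torch.randn(3, 4)).shape)  # torch.Size([3, 2])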

0 commit comments
