
Commit 80f0832

vivekmig authored and facebook-github-bot committed
Fix layer conductance pyre fixme issues
Differential Revision: D67705320
1 parent 982f35b

1 file changed: 6 additions, 14 deletions

captum/attr/_core/layer/layer_conductance.py

@@ -2,7 +2,7 @@
 
 # pyre-strict
 import typing
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -44,8 +44,7 @@ class LayerConductance(LayerAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
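A note for readers following the typing changes: a bare Callable trips Pyre's strict check [24] because it specifies neither parameters nor a return type, while Callable[..., Tensor] pins only the return type. A minimal sketch of the difference (the function name double is illustrative, not from the diff):

from typing import Callable

import torch
from torch import Tensor

def double(x: Tensor) -> Tensor:
    return x * 2

# A bare ``Callable`` leaves parameters and return type unspecified, which
# strict Pyre flags as pyre-fixme[24].
f_untyped: Callable = double

# ``Callable[..., Tensor]`` still accepts any argument list but pins the
# return type, which is all the attribution code relies on here.
f_typed: Callable[..., Tensor] = double

out: Tensor = f_typed(torch.ones(3))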
@@ -73,8 +72,7 @@ def has_convergence_delta(self) -> bool:
         return True
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    #  arguments of overload defined on line `75`.
+    @log_usage()
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -91,8 +89,7 @@ def attribute(
     ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    #  arguments of overload defined on line `91`.
+    @log_usage()
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -108,8 +105,6 @@ def attribute(
     ) -> Union[Tensor, Tuple[Tensor, ...]]: ...
 
     @log_usage()
-    # pyre-fixme[43]: This definition does not have the same decorators as the
-    #  preceding overload(s).
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
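The three hunks above resolve the same pyre-fixme[43]: Pyre requires @typing.overload stubs to carry the same decorators as the implementation, so @log_usage() now appears on all of them. A hedged sketch of the pattern, using a toy stand-in for captum's @log_usage():

import typing
from typing import Any, Callable, Union

def log_usage() -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    # Toy stand-in for captum's @log_usage() decorator (assumption: the
    # real one wraps the function to record API usage).
    def wrap(fn: Callable[..., Any]) -> Callable[..., Any]:
        return fn
    return wrap

class Example:
    @typing.overload
    @log_usage()
    def attribute(self, inputs: int) -> int: ...

    @typing.overload
    @log_usage()
    def attribute(self, inputs: str) -> str: ...

    # Pyre's [43] check wants the implementation and its overload stubs to
    # carry the same decorators; decorating all three uniformly satisfies it.
    @log_usage()
    def attribute(self, inputs: Union[int, str]) -> Union[int, str]:
        return inputs

print(Example().attribute(3))  # -> 3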
@@ -376,7 +371,7 @@ def _attribute(
             layer_evals,
         ) = compute_layer_gradients_and_eval(
             forward_fn=self.forward_func,
-            layer=self.layer,
+            layer=cast(Module, self.layer),
             inputs=scaled_features_tpl,
             additional_forward_args=input_additional_args,
             target_ind=expanded_target,
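The cast is plausibly needed because self.layer is typed to allow either a single Module or a list of Modules, while this call site passes a single Module (an assumption mirroring captum's layer attribution API); typing.cast narrows the static type with no runtime effect. A minimal sketch:

from typing import List, Union, cast

import torch.nn as nn

# The layer attribute may be a single Module or a list of Modules
# (assumption mirroring captum's ModuleOrModuleList).
layer: Union[nn.Module, List[nn.Module]] = nn.Linear(4, 2)

def describe(m: nn.Module) -> str:
    return type(m).__name__

# cast() is a static-typing assertion only: it returns its argument
# unchanged at runtime and performs no validation.
print(describe(cast(nn.Module, layer)))  # -> "Linear"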
@@ -389,8 +384,6 @@ def _attribute(
         # This approximates the total input gradient of each step multiplied
         # by the step size.
         grad_diffs = tuple(
-            # pyre-fixme[58]: `-` is not supported for operand types `Tuple[Tensor,
-            #  ...]` and `Tuple[Tensor, ...]`.
             layer_eval[num_examples:] - layer_eval[:-num_examples]
             for layer_eval in layer_evals
         )
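The deleted fixme[58] reflected Pyre inferring layer_eval as a tuple; at runtime each layer_eval is a Tensor, so the slice-and-subtract is ordinary elementwise arithmetic. A quick check of the pattern (shapes illustrative):

import torch

num_examples = 2
# layer_eval stacks activations for (n_steps + 1) scaled batches of
# num_examples rows each.
layer_eval = torch.arange(12.0).reshape(6, 2)

# Slicing a Tensor and subtracting is elementwise Tensor arithmetic;
# no tuple arithmetic is involved.
grad_diff = layer_eval[num_examples:] - layer_eval[:-num_examples]
print(grad_diff.shape)  # torch.Size([4, 2])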
@@ -403,8 +396,7 @@ def _attribute(
                 grad_diff * layer_gradient[:-num_examples],
                 n_steps,
                 num_examples,
-                # pyre-fixme[16]: `tuple` has no attribute `shape`.
-                layer_eval.shape[1:],
+                tuple(layer_eval.shape[1:]),
             )
             for layer_gradient, layer_eval, grad_diff in zip(
                 layer_gradients, layer_evals, grad_diffs
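Similarly, the fixme[16] arose from Pyre treating layer_eval as a bare tuple; with the typing tightened elsewhere in this commit, the tuple() wrap presumably hands the downstream helper an explicit Tuple[int, ...] rather than a torch.Size. At the value level nothing changes, since torch.Size is already a tuple subclass:

import torch

layer_eval = torch.zeros(6, 3, 5)

# ``Tensor.shape`` is a ``torch.Size``, itself a subclass of ``tuple``;
# slicing it yields another ``torch.Size``.
print(isinstance(layer_eval.shape, tuple))  # True
print(layer_eval.shape[1:])                 # torch.Size([3, 5])

# Wrapping in ``tuple()`` changes nothing at runtime but gives type
# checkers a plain ``Tuple[int, ...]``.
print(tuple(layer_eval.shape[1:]))          # (3, 5)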
