22
33# pyre-strict
44import typing
5- from typing import Any , Callable , Dict , List , Literal , Optional , Tuple , Union
5+ from typing import Any , Callable , cast , Dict , List , Literal , Optional , Tuple , Union
66
77import torch
88from captum ._utils .common import (
@@ -44,8 +44,7 @@ class LayerConductance(LayerAttribution, GradientAttribution):
4444
4545 def __init__ (
4646 self ,
47- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
48- forward_func : Callable ,
47+ forward_func : Callable [..., Tensor ],
4948 layer : Module ,
5049 device_ids : Union [None , List [int ]] = None ,
5150 ) -> None :
@@ -73,8 +72,7 @@ def has_convergence_delta(self) -> bool:
7372 return True
7473
7574 @typing .overload
76- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
77- # arguments of overload defined on line `75`.
75+ @log_usage ()
7876 def attribute (
7977 self ,
8078 inputs : Union [Tensor , Tuple [Tensor , ...]],
@@ -91,8 +89,7 @@ def attribute(
9189 ) -> Tuple [Union [Tensor , Tuple [Tensor , ...]], Tensor ]: ...
9290
9391 @typing .overload
94- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
95- # arguments of overload defined on line `91`.
92+ @log_usage ()
9693 def attribute (
9794 self ,
9895 inputs : Union [Tensor , Tuple [Tensor , ...]],
@@ -108,8 +105,6 @@ def attribute(
108105 ) -> Union [Tensor , Tuple [Tensor , ...]]: ...
109106
110107 @log_usage ()
111- # pyre-fixme[43]: This definition does not have the same decorators as the
112- # preceding overload(s).
113108 def attribute (
114109 self ,
115110 inputs : Union [Tensor , Tuple [Tensor , ...]],
@@ -376,7 +371,7 @@ def _attribute(
376371 layer_evals ,
377372 ) = compute_layer_gradients_and_eval (
378373 forward_fn = self .forward_func ,
379- layer = self .layer ,
374+ layer = cast ( Module , self .layer ) ,
380375 inputs = scaled_features_tpl ,
381376 additional_forward_args = input_additional_args ,
382377 target_ind = expanded_target ,
@@ -389,8 +384,6 @@ def _attribute(
389384 # This approximates the total input gradient of each step multiplied
390385 # by the step size.
391386 grad_diffs = tuple (
392- # pyre-fixme[58]: `-` is not supported for operand types `Tuple[Tensor,
393- # ...]` and `Tuple[Tensor, ...]`.
394387 layer_eval [num_examples :] - layer_eval [:- num_examples ]
395388 for layer_eval in layer_evals
396389 )
@@ -403,8 +396,7 @@ def _attribute(
403396 grad_diff * layer_gradient [:- num_examples ],
404397 n_steps ,
405398 num_examples ,
406- # pyre-fixme[16]: `tuple` has no attribute `shape`.
407- layer_eval .shape [1 :],
399+ tuple (layer_eval .shape [1 :]),
408400 )
409401 for layer_gradient , layer_eval , grad_diff in zip (
410402 layer_gradients , layer_evals , grad_diffs
0 commit comments