Skip to content

Commit 7bb3f33

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix neuron guided backprop pyre fixme issues (#1465)
Summary: Pull Request resolved: #1465 Differential Revision: D67705091
1 parent e38a3a5 commit 7bb3f33

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
from typing import Callable, List, Optional, Tuple, Union
55

66
from captum._utils.gradient import construct_neuron_grad_fn
7-
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
7+
from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric
88
from captum.attr._core.guided_backprop_deconvnet import Deconvolution, GuidedBackprop
99
from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
1010
from captum.log import log_usage
11+
from torch import Tensor
1112
from torch.nn import Module
1213

1314

@@ -60,8 +61,11 @@ def __init__(
6061
def attribute(
6162
self,
6263
inputs: TensorOrTupleOfTensorsGeneric,
63-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
64-
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
64+
neuron_selector: Union[
65+
int,
66+
Tuple[Union[int, SliceIntType], ...],
67+
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
68+
],
6569
additional_forward_args: Optional[object] = None,
6670
attribute_to_neuron_input: bool = False,
6771
) -> TensorOrTupleOfTensorsGeneric:
@@ -215,8 +219,11 @@ def __init__(
215219
def attribute(
216220
self,
217221
inputs: TensorOrTupleOfTensorsGeneric,
218-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
219-
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
222+
neuron_selector: Union[
223+
int,
224+
Tuple[Union[int, SliceIntType], ...],
225+
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
226+
],
220227
additional_forward_args: Optional[object] = None,
221228
attribute_to_neuron_input: bool = False,
222229
) -> TensorOrTupleOfTensorsGeneric:

0 commit comments

Comments
 (0)