Skip to content

Commit 57d1f93

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix neuron gradient pyre fixme issues (#1464)
Summary: Pull Request resolved: #1464 Fixing unresolved pyre fixme issues in corresponding file Reviewed By: craymichael Differential Revision: D67704365 fbshipit-source-id: f9d210e4f7afb8287bfcd9c8edd9e2a998b0a50b
1 parent 27c2004 commit 57d1f93

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

captum/attr/_core/neuron/neuron_gradient.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
apply_gradient_requirements,
1515
undo_gradient_requirements,
1616
)
17-
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
17+
from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric
1818
from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
1919
from captum.log import log_usage
20+
from torch import Tensor
2021
from torch.nn import Module
2122

2223

@@ -28,8 +29,7 @@ class NeuronGradient(NeuronAttribution, GradientAttribution):
2829

2930
def __init__(
3031
self,
31-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
32-
forward_func: Callable,
32+
forward_func: Callable[..., Union[int, float, Tensor]],
3333
layer: Module,
3434
device_ids: Union[None, List[int]] = None,
3535
) -> None:
@@ -60,8 +60,11 @@ def __init__(
6060
def attribute(
6161
self,
6262
inputs: TensorOrTupleOfTensorsGeneric,
63-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
64-
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
63+
neuron_selector: Union[
64+
int,
65+
Tuple[Union[int, SliceIntType], ...],
66+
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
67+
],
6568
additional_forward_args: Optional[object] = None,
6669
attribute_to_neuron_input: bool = False,
6770
) -> TensorOrTupleOfTensorsGeneric:
@@ -162,18 +165,12 @@ def attribute(
162165
>>> # index (4,1,2).
163166
>>> attribution = neuron_ig.attribute(input, (4,1,2))
164167
"""
165-
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
166-
# `TensorOrTupleOfTensorsGeneric`.
167168
is_inputs_tuple = _is_tuple(inputs)
168-
# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
169-
# `Tuple[Tensor, ...]`.
170-
inputs = _format_tensor_into_tuples(inputs)
169+
inputs_tuple = _format_tensor_into_tuples(inputs)
171170
additional_forward_args = _format_additional_forward_args(
172171
additional_forward_args
173172
)
174-
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
175-
# `TensorOrTupleOfTensorsGeneric`.
176-
gradient_mask = apply_gradient_requirements(inputs)
173+
gradient_mask = apply_gradient_requirements(inputs_tuple)
177174

178175
_, input_grads = _forward_layer_eval_with_neuron_grads(
179176
self.forward_func,
@@ -185,9 +182,9 @@ def attribute(
185182
attribute_to_layer_input=attribute_to_neuron_input,
186183
)
187184

188-
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
189-
# `TensorOrTupleOfTensorsGeneric`.
190-
undo_gradient_requirements(inputs, gradient_mask)
191-
# pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
192-
# `Tuple[Tensor, ...]`.
185+
undo_gradient_requirements(inputs_tuple, gradient_mask)
186+
187+
# pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <:
188+
# [Tensor, typing.Tuple[Tensor, ...]]]` but got `Union[Tensor,
189+
# typing.Tuple[Tensor, ...]]`.
193190
return _format_output(is_inputs_tuple, input_grads)

0 commit comments

Comments
 (0)