1414 apply_gradient_requirements ,
1515 undo_gradient_requirements ,
1616)
17- from captum ._utils .typing import TensorOrTupleOfTensorsGeneric
17+ from captum ._utils .typing import SliceIntType , TensorOrTupleOfTensorsGeneric
1818from captum .attr ._utils .attribution import GradientAttribution , NeuronAttribution
1919from captum .log import log_usage
20+ from torch import Tensor
2021from torch .nn import Module
2122
2223
@@ -28,8 +29,7 @@ class NeuronGradient(NeuronAttribution, GradientAttribution):
2829
2930 def __init__ (
3031 self ,
31- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
32- forward_func : Callable ,
32+ forward_func : Callable [..., Union [int , float , Tensor ]],
3333 layer : Module ,
3434 device_ids : Union [None , List [int ]] = None ,
3535 ) -> None :
@@ -60,8 +60,11 @@ def __init__(
6060 def attribute (
6161 self ,
6262 inputs : TensorOrTupleOfTensorsGeneric ,
63- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
64- neuron_selector : Union [int , Tuple [Union [int , slice ], ...], Callable ],
63+ neuron_selector : Union [
64+ int ,
65+ Tuple [Union [int , SliceIntType ], ...],
66+ Callable [[Union [Tensor , Tuple [Tensor , ...]]], Tensor ],
67+ ],
6568 additional_forward_args : Optional [object ] = None ,
6669 attribute_to_neuron_input : bool = False ,
6770 ) -> TensorOrTupleOfTensorsGeneric :
@@ -162,18 +165,12 @@ def attribute(
162165 >>> # index (4,1,2).
163166 >>> attribution = neuron_ig.attribute(input, (4,1,2))
164167 """
165- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
166- # `TensorOrTupleOfTensorsGeneric`.
167168 is_inputs_tuple = _is_tuple (inputs )
168- # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
169- # `Tuple[Tensor, ...]`.
170- inputs = _format_tensor_into_tuples (inputs )
169+ inputs_tuple = _format_tensor_into_tuples (inputs )
171170 additional_forward_args = _format_additional_forward_args (
172171 additional_forward_args
173172 )
174- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
175- # `TensorOrTupleOfTensorsGeneric`.
176- gradient_mask = apply_gradient_requirements (inputs )
173+ gradient_mask = apply_gradient_requirements (inputs_tuple )
177174
178175 _ , input_grads = _forward_layer_eval_with_neuron_grads (
179176 self .forward_func ,
@@ -185,9 +182,9 @@ def attribute(
185182 attribute_to_layer_input = attribute_to_neuron_input ,
186183 )
187184
188- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
189- # `TensorOrTupleOfTensorsGeneric`.
190- undo_gradient_requirements ( inputs , gradient_mask )
191- # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
192- # ` Tuple[Tensor, ...]`.
185+ undo_gradient_requirements ( inputs_tuple , gradient_mask )
186+
187+ # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <:
188+ # [Tensor, typing.Tuple[Tensor, ...]]]` but got `Union[Tensor,
189+ # typing. Tuple[Tensor, ...] ]`.
193190 return _format_output (is_inputs_tuple , input_grads )
0 commit comments