@@ -39,8 +39,7 @@ class NeuronConductance(NeuronAttribution, GradientAttribution):
3939
4040 def __init__ (
4141 self ,
42- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
43- forward_func : Callable ,
42+ forward_func : Callable [..., Tensor ],
4443 layer : Module ,
4544 device_ids : Union [None , List [int ]] = None ,
4645 multiply_by_inputs : bool = True ,
@@ -94,8 +93,9 @@ def __init__(
9493 def attribute (
9594 self ,
9695 inputs : TensorOrTupleOfTensorsGeneric ,
97- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
98- neuron_selector : Union [int , Tuple [int , ...], Callable ],
96+ neuron_selector : Union [
97+ int , Tuple [int , ...], Callable [[Union [Tensor , Tuple [Tensor , ...]]], Tensor ]
98+ ],
9999 baselines : BaselineType = None ,
100100 target : TargetType = None ,
101101 additional_forward_args : Optional [object ] = None ,
@@ -285,8 +285,6 @@ def attribute(
285285 " results." ,
286286 stacklevel = 1 ,
287287 )
288- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
289- # `TensorOrTupleOfTensorsGeneric`.
290288 is_inputs_tuple = _is_tuple (inputs )
291289
292290 # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
@@ -334,8 +332,9 @@ def attribute(
334332 def _attribute (
335333 self ,
336334 inputs : Tuple [Tensor , ...],
337- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
338- neuron_selector : Union [int , Tuple [int , ...], Callable ],
335+ neuron_selector : Union [
336+ int , Tuple [int , ...], Callable [[Union [Tensor , Tuple [Tensor , ...]]], Tensor ]
337+ ],
339338 baselines : Tuple [Union [Tensor , int , float ], ...],
340339 target : TargetType = None ,
341340 additional_forward_args : Optional [object ] = None ,
0 commit comments