@@ -39,8 +39,7 @@ class NeuronConductance(NeuronAttribution, GradientAttribution):
3939
4040 def __init__ (
4141 self ,
42- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
43- forward_func : Callable ,
42+ forward_func : Callable [..., Tensor ],
4443 layer : Module ,
4544 device_ids : Union [None , List [int ]] = None ,
4645 multiply_by_inputs : bool = True ,
@@ -94,8 +93,11 @@ def __init__(
9493 def attribute (
9594 self ,
9695 inputs : TensorOrTupleOfTensorsGeneric ,
97- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
98- neuron_selector : Union [int , Tuple [int , ...], Callable ],
96+ neuron_selector : Union [
97+ int ,
98+ Tuple [Union [int , slice [int , int , int ]], ...],
99+ Callable [[Union [Tensor , Tuple [Tensor , ...]]], Tensor ],
100+ ],
99101 baselines : BaselineType = None ,
100102 target : TargetType = None ,
101103 additional_forward_args : Optional [object ] = None ,
@@ -285,8 +287,6 @@ def attribute(
285287 " results." ,
286288 stacklevel = 1 ,
287289 )
288- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
289- # `TensorOrTupleOfTensorsGeneric`.
290290 is_inputs_tuple = _is_tuple (inputs )
291291
292292 # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
@@ -334,8 +334,11 @@ def attribute(
334334 def _attribute (
335335 self ,
336336 inputs : Tuple [Tensor , ...],
337- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
338- neuron_selector : Union [int , Tuple [int , ...], Callable ],
337+ neuron_selector : Union [
338+ int ,
339+ Tuple [Union [int , slice [int , int , int ]], ...],
340+ Callable [[Union [Tensor , Tuple [Tensor , ...]]], Tensor ],
341+ ],
339342 baselines : Tuple [Union [Tensor , int , float ], ...],
340343 target : TargetType = None ,
341344 additional_forward_args : Optional [object ] = None ,
0 commit comments