1414 _verify_select_neuron ,
1515)
1616from captum ._utils .gradient import compute_layer_gradients_and_eval
17- from captum ._utils .typing import BaselineType , TargetType , TensorOrTupleOfTensorsGeneric
17+ from captum ._utils .typing import (
18+ BaselineType ,
19+ SliceIntType ,
20+ TargetType ,
21+ TensorOrTupleOfTensorsGeneric ,
22+ )
1823from captum .attr ._utils .approximation_methods import approximation_parameters
1924from captum .attr ._utils .attribution import GradientAttribution , NeuronAttribution
2025from captum .attr ._utils .batching import _batch_attribution
@@ -39,8 +44,7 @@ class NeuronConductance(NeuronAttribution, GradientAttribution):
3944
4045 def __init__ (
4146 self ,
42- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
43- forward_func : Callable ,
47+ forward_func : Callable [..., Tensor ],
4448 layer : Module ,
4549 device_ids : Union [None , List [int ]] = None ,
4650 multiply_by_inputs : bool = True ,
@@ -94,8 +98,11 @@ def __init__(
9498 def attribute (
9599 self ,
96100 inputs : TensorOrTupleOfTensorsGeneric ,
97- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
98- neuron_selector : Union [int , Tuple [int , ...], Callable ],
101+ neuron_selector : Union [
102+ int ,
103+ Tuple [Union [int , SliceIntType ], ...],
104+ Callable [[Union [Tensor , Tuple [Tensor , ...]]], Tensor ],
105+ ],
99106 baselines : BaselineType = None ,
100107 target : TargetType = None ,
101108 additional_forward_args : Optional [object ] = None ,
@@ -285,28 +292,24 @@ def attribute(
285292 " results." ,
286293 stacklevel = 1 ,
287294 )
288- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
289- # `TensorOrTupleOfTensorsGeneric`.
290295 is_inputs_tuple = _is_tuple (inputs )
291296
292- # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
293- # `Tuple[Tensor, ...]`.
294- inputs , baselines = _format_input_baseline (inputs , baselines )
295- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
296- # `TensorOrTupleOfTensorsGeneric`.
297- _validate_input (inputs , baselines , n_steps , method )
297+ formatted_inputs , formatted_baselines = _format_input_baseline (
298+ inputs , baselines
299+ )
300+ _validate_input (formatted_inputs , formatted_baselines , n_steps , method )
298301
299- num_examples = inputs [0 ].shape [0 ]
302+ num_examples = formatted_inputs [0 ].shape [0 ]
300303
301304 if internal_batch_size is not None :
302- num_examples = inputs [0 ].shape [0 ]
305+ num_examples = formatted_inputs [0 ].shape [0 ]
303306 attrs = _batch_attribution (
304307 self ,
305308 num_examples ,
306309 internal_batch_size ,
307310 n_steps ,
308- inputs = inputs ,
309- baselines = baselines ,
311+ inputs = formatted_inputs ,
312+ baselines = formatted_baselines ,
310313 neuron_selector = neuron_selector ,
311314 target = target ,
312315 additional_forward_args = additional_forward_args ,
@@ -315,11 +318,9 @@ def attribute(
315318 )
316319 else :
317320 attrs = self ._attribute (
318- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
319- # got `TensorOrTupleOfTensorsGeneric`.
320- inputs = inputs ,
321+ inputs = formatted_inputs ,
321322 neuron_selector = neuron_selector ,
322- baselines = baselines ,
323+ baselines = formatted_baselines ,
323324 target = target ,
324325 additional_forward_args = additional_forward_args ,
325326 n_steps = n_steps ,
@@ -334,8 +335,11 @@ def attribute(
334335 def _attribute (
335336 self ,
336337 inputs : Tuple [Tensor , ...],
337- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
338- neuron_selector : Union [int , Tuple [int , ...], Callable ],
338+ neuron_selector : Union [
339+ int ,
340+ Tuple [Union [int , SliceIntType ], ...],
341+ Callable [[Union [Tensor , Tuple [Tensor , ...]]], Tensor ],
342+ ],
339343 baselines : Tuple [Union [Tensor , int , float ], ...],
340344 target : TargetType = None ,
341345 additional_forward_args : Optional [object ] = None ,
@@ -409,8 +413,9 @@ def _attribute(
409413
410414 # Aggregates across all steps for each tensor in the input tuple
411415 total_grads = tuple (
412- # pyre-fixme[6]: For 4th argument expected `Tuple[int, ...]` but got `Size`.
413- _reshape_and_sum (scaled_grad , n_steps , num_examples , input_grad .shape [1 :])
416+ _reshape_and_sum (
417+ scaled_grad , n_steps , num_examples , tuple (input_grad .shape [1 :])
418+ )
414419 for (scaled_grad , input_grad ) in zip (scaled_grads , input_grads )
415420 )
416421
0 commit comments