44from typing import Callable , List , Optional , Tuple , Union
55
66from captum ._utils .gradient import construct_neuron_grad_fn
7- from captum ._utils .typing import TensorOrTupleOfTensorsGeneric
7+ from captum ._utils .typing import SliceIntType , TensorOrTupleOfTensorsGeneric
88from captum .attr ._core .integrated_gradients import IntegratedGradients
99from captum .attr ._utils .attribution import GradientAttribution , NeuronAttribution
1010from captum .log import log_usage
@@ -27,8 +27,7 @@ class NeuronIntegratedGradients(NeuronAttribution, GradientAttribution):
2727
2828 def __init__ (
2929 self ,
30- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
31- forward_func : Callable ,
30+ forward_func : Callable [..., Tensor ],
3231 layer : Module ,
3332 device_ids : Union [None , List [int ]] = None ,
3433 multiply_by_inputs : bool = True ,
@@ -76,8 +75,11 @@ def __init__(
7675 def attribute (
7776 self ,
7877 inputs : TensorOrTupleOfTensorsGeneric ,
79- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
80- neuron_selector : Union [int , Tuple [Union [int , slice ], ...], Callable ],
78+ neuron_selector : Union [
79+ int ,
80+ Tuple [Union [int , SliceIntType ], ...],
81+ Callable [[Union [Tensor , Tuple [Tensor , ...]]], Tensor ],
82+ ],
8183 baselines : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
8284 additional_forward_args : Optional [object ] = None ,
8385 n_steps : int = 50 ,
0 commit comments