66import torch
77from captum ._utils .common import _verify_select_neuron
88from captum ._utils .gradient import _forward_layer_eval
9- from captum ._utils .typing import BaselineType , TensorOrTupleOfTensorsGeneric
9+ from captum ._utils .typing import (
10+ BaselineType ,
11+ SliceIntType ,
12+ TensorOrTupleOfTensorsGeneric ,
13+ )
1014from captum .attr ._core .feature_ablation import FeatureAblation
1115from captum .attr ._utils .attribution import NeuronAttribution , PerturbationAttribution
1216from captum .log import log_usage
@@ -31,8 +35,7 @@ class NeuronFeatureAblation(NeuronAttribution, PerturbationAttribution):
3135
3236 def __init__ (
3337 self ,
34- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
35- forward_func : Callable ,
38+ forward_func : Callable [..., Union [int , float , Tensor ]],
3639 layer : Module ,
3740 device_ids : Union [None , List [int ]] = None ,
3841 ) -> None :
@@ -61,8 +64,11 @@ def __init__(
6164 def attribute (
6265 self ,
6366 inputs : TensorOrTupleOfTensorsGeneric ,
64- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
65- neuron_selector : Union [int , Tuple [Union [int , slice ], ...], Callable ],
67+ neuron_selector : Union [
68+ int ,
69+ Tuple [Union [int , SliceIntType ], ...],
70+ Callable [[Union [Tensor , Tuple [Tensor , ...]]], Tensor ],
71+ ],
6672 baselines : BaselineType = None ,
6773 additional_forward_args : Optional [object ] = None ,
6874 feature_mask : Union [None , TensorOrTupleOfTensorsGeneric ] = None ,
@@ -250,8 +256,7 @@ def attribute(
250256 >>> feature_mask=feature_mask)
251257 """
252258
253- # pyre-fixme[3]: Return type must be annotated.
254- def neuron_forward_func (* args : Any ):
259+ def neuron_forward_func (* args : Any ) -> Tensor :
255260 with torch .no_grad ():
256261 layer_eval = _forward_layer_eval (
257262 self .forward_func ,
0 commit comments