We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 85b769d commit 012456aCopy full SHA for 012456a
captum/attr/_core/neuron/neuron_integrated_gradients.py
@@ -27,8 +27,7 @@ class NeuronIntegratedGradients(NeuronAttribution, GradientAttribution):
27
28
def __init__(
29
self,
30
- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
31
- forward_func: Callable,
+ forward_func: Callable[..., Tensor],
32
layer: Module,
33
device_ids: Union[None, List[int]] = None,
34
multiply_by_inputs: bool = True,
@@ -76,8 +75,11 @@ def __init__(
76
75
def attribute(
77
78
inputs: TensorOrTupleOfTensorsGeneric,
79
80
- neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
+ neuron_selector: Union[
+ int,
+ Tuple[Union[int, slice[int, int, int]], ...],
81
+ Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
82
+ ],
83
baselines: Union[None, Tensor, Tuple[Tensor, ...]] = None,
84
additional_forward_args: Optional[object] = None,
85
n_steps: int = 50,
0 commit comments