|
4 | 4 | from typing import Callable, List, Optional, Tuple, Union |
5 | 5 |
|
6 | 6 | from captum._utils.gradient import construct_neuron_grad_fn |
7 | | -from captum._utils.typing import TensorOrTupleOfTensorsGeneric |
| 7 | +from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric |
8 | 8 | from captum.attr._core.guided_backprop_deconvnet import Deconvolution, GuidedBackprop |
9 | 9 | from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution |
10 | 10 | from captum.log import log_usage |
| 11 | +from torch import Tensor |
11 | 12 | from torch.nn import Module |
12 | 13 |
|
13 | 14 |
|
@@ -60,8 +61,11 @@ def __init__( |
60 | 61 | def attribute( |
61 | 62 | self, |
62 | 63 | inputs: TensorOrTupleOfTensorsGeneric, |
63 | | - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. |
64 | | - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], |
| 64 | + neuron_selector: Union[ |
| 65 | + int, |
| 66 | + Tuple[Union[int, SliceIntType], ...], |
| 67 | + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], |
| 68 | + ], |
65 | 69 | additional_forward_args: Optional[object] = None, |
66 | 70 | attribute_to_neuron_input: bool = False, |
67 | 71 | ) -> TensorOrTupleOfTensorsGeneric: |
@@ -215,8 +219,11 @@ def __init__( |
215 | 219 | def attribute( |
216 | 220 | self, |
217 | 221 | inputs: TensorOrTupleOfTensorsGeneric, |
218 | | - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. |
219 | | - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], |
| 222 | + neuron_selector: Union[ |
| 223 | + int, |
| 224 | + Tuple[Union[int, SliceIntType], ...], |
| 225 | + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], |
| 226 | + ], |
220 | 227 | additional_forward_args: Optional[object] = None, |
221 | 228 | attribute_to_neuron_input: bool = False, |
222 | 229 | ) -> TensorOrTupleOfTensorsGeneric: |
|
0 commit comments