|
4 | 4 | from typing import Callable, List, Optional, Tuple, Union |
5 | 5 |
|
6 | 6 | from captum._utils.gradient import construct_neuron_grad_fn |
7 | | -from captum._utils.typing import TensorOrTupleOfTensorsGeneric |
| 7 | +from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric |
8 | 8 | from captum.attr._core.gradient_shap import GradientShap |
9 | 9 | from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution |
10 | 10 | from captum.log import log_usage |
| 11 | +from torch import Tensor |
11 | 12 | from torch.nn import Module |
12 | 13 |
|
13 | 14 |
|
@@ -50,8 +51,7 @@ class NeuronGradientShap(NeuronAttribution, GradientAttribution): |
50 | 51 |
|
51 | 52 | def __init__( |
52 | 53 | self, |
53 | | - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. |
54 | | - forward_func: Callable, |
| 54 | + forward_func: Callable[..., Union[int, float, Tensor]], |
55 | 55 | layer: Module, |
56 | 56 | device_ids: Union[None, List[int]] = None, |
57 | 57 | multiply_by_inputs: bool = True, |
@@ -97,8 +97,11 @@ def __init__( |
97 | 97 | def attribute( |
98 | 98 | self, |
99 | 99 | inputs: TensorOrTupleOfTensorsGeneric, |
100 | | - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. |
101 | | - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], |
| 100 | + neuron_selector: Union[ |
| 101 | + int, |
| 102 | + Tuple[Union[int, SliceIntType], ...], |
| 103 | + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], |
| 104 | + ], |
102 | 105 | baselines: Union[ |
103 | 106 | TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] |
104 | 107 | ], |
|
0 commit comments