Skip to content

Commit 2db41e7

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix neuron gradient shap pyre fixme issues (#1463)
Summary: Pull Request resolved: #1463 Differential Revision: D67705098
1 parent 6e03bab commit 2db41e7

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

captum/attr/_core/neuron/neuron_gradient_shap.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
from typing import Callable, List, Optional, Tuple, Union
55

66
from captum._utils.gradient import construct_neuron_grad_fn
7-
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
7+
from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric
88
from captum.attr._core.gradient_shap import GradientShap
99
from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
1010
from captum.log import log_usage
11+
from torch import Tensor
1112
from torch.nn import Module
1213

1314

@@ -50,8 +51,7 @@ class NeuronGradientShap(NeuronAttribution, GradientAttribution):
5051

5152
def __init__(
5253
self,
53-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
54-
forward_func: Callable,
54+
forward_func: Callable[..., Union[int, float, Tensor]],
5555
layer: Module,
5656
device_ids: Union[None, List[int]] = None,
5757
multiply_by_inputs: bool = True,
@@ -97,8 +97,11 @@ def __init__(
9797
def attribute(
9898
self,
9999
inputs: TensorOrTupleOfTensorsGeneric,
100-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
101-
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
100+
neuron_selector: Union[
101+
int,
102+
Tuple[Union[int, SliceIntType], ...],
103+
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
104+
],
102105
baselines: Union[
103106
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
104107
],

0 commit comments

Comments
 (0)