Skip to content

Commit 012456a

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix neuron_integrated_gradients pyre fixme issues (#1457)
Summary: Pull Request resolved: #1457 Differential Revision: D67523072
1 parent 85b769d commit 012456a

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

captum/attr/_core/neuron/neuron_integrated_gradients.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ class NeuronIntegratedGradients(NeuronAttribution, GradientAttribution):
2727

2828
def __init__(
2929
self,
30-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
31-
forward_func: Callable,
30+
forward_func: Callable[..., Tensor],
3231
layer: Module,
3332
device_ids: Union[None, List[int]] = None,
3433
multiply_by_inputs: bool = True,
@@ -76,8 +75,11 @@ def __init__(
7675
def attribute(
7776
self,
7877
inputs: TensorOrTupleOfTensorsGeneric,
79-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
80-
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
78+
neuron_selector: Union[
79+
int,
80+
Tuple[Union[int, slice[int, int, int]], ...],
81+
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
82+
],
8183
baselines: Union[None, Tensor, Tuple[Tensor, ...]] = None,
8284
additional_forward_args: Optional[object] = None,
8385
n_steps: int = 50,

0 commit comments

Comments
 (0)