Skip to content

Commit 56961b0

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix neuron conductance pyre fixme issues (#1458)
Summary: Pull Request resolved: #1458 Differential Revision: D67523217
1 parent 1d437c3 commit 56961b0

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

captum/attr/_core/neuron/neuron_conductance.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ class NeuronConductance(NeuronAttribution, GradientAttribution):
3939

4040
def __init__(
4141
self,
42-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
43-
forward_func: Callable,
42+
forward_func: Callable[..., Tensor],
4443
layer: Module,
4544
device_ids: Union[None, List[int]] = None,
4645
multiply_by_inputs: bool = True,
@@ -94,8 +93,11 @@ def __init__(
9493
def attribute(
9594
self,
9695
inputs: TensorOrTupleOfTensorsGeneric,
97-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
98-
neuron_selector: Union[int, Tuple[int, ...], Callable],
96+
neuron_selector: Union[
97+
int,
98+
Tuple[Union[int, slice[int, int, int]], ...],
99+
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
100+
],
99101
baselines: BaselineType = None,
100102
target: TargetType = None,
101103
additional_forward_args: Optional[object] = None,
@@ -285,8 +287,6 @@ def attribute(
285287
" results.",
286288
stacklevel=1,
287289
)
288-
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
289-
# `TensorOrTupleOfTensorsGeneric`.
290290
is_inputs_tuple = _is_tuple(inputs)
291291

292292
# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
@@ -334,8 +334,11 @@ def attribute(
334334
def _attribute(
335335
self,
336336
inputs: Tuple[Tensor, ...],
337-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
338-
neuron_selector: Union[int, Tuple[int, ...], Callable],
337+
neuron_selector: Union[
338+
int,
339+
Tuple[Union[int, slice[int, int, int]], ...],
340+
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
341+
],
339342
baselines: Tuple[Union[Tensor, int, float], ...],
340343
target: TargetType = None,
341344
additional_forward_args: Optional[object] = None,

0 commit comments

Comments
 (0)