Skip to content

Commit 8a4ea82

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix neuron conductance pyre fixme issues
Differential Revision: D67523217
1 parent 26e56f7 commit 8a4ea82

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

captum/attr/_core/neuron/neuron_conductance.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ class NeuronConductance(NeuronAttribution, GradientAttribution):
3939

4040
def __init__(
4141
self,
42-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
43-
forward_func: Callable,
42+
forward_func: Callable[..., Tensor],
4443
layer: Module,
4544
device_ids: Union[None, List[int]] = None,
4645
multiply_by_inputs: bool = True,
@@ -94,8 +93,9 @@ def __init__(
9493
def attribute(
9594
self,
9695
inputs: TensorOrTupleOfTensorsGeneric,
97-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
98-
neuron_selector: Union[int, Tuple[int, ...], Callable],
96+
neuron_selector: Union[
97+
int, Tuple[int, ...], Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor]
98+
],
9999
baselines: BaselineType = None,
100100
target: TargetType = None,
101101
additional_forward_args: Optional[object] = None,
@@ -285,8 +285,6 @@ def attribute(
285285
" results.",
286286
stacklevel=1,
287287
)
288-
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
289-
# `TensorOrTupleOfTensorsGeneric`.
290288
is_inputs_tuple = _is_tuple(inputs)
291289

292290
# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
@@ -334,8 +332,9 @@ def attribute(
334332
def _attribute(
335333
self,
336334
inputs: Tuple[Tensor, ...],
337-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
338-
neuron_selector: Union[int, Tuple[int, ...], Callable],
335+
neuron_selector: Union[
336+
int, Tuple[int, ...], Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor]
337+
],
339338
baselines: Tuple[Union[Tensor, int, float], ...],
340339
target: TargetType = None,
341340
additional_forward_args: Optional[object] = None,

0 commit comments

Comments
 (0)