Skip to content

Commit 6e03bab

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix neuron feature ablation pyre fixme issues (#1462)
Summary: Pull Request resolved: #1462 Differential Revision: D67705096
1 parent 7f63eec commit 6e03bab

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

captum/attr/_core/neuron/neuron_feature_ablation.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
import torch
77
from captum._utils.common import _verify_select_neuron
88
from captum._utils.gradient import _forward_layer_eval
9-
from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric
9+
from captum._utils.typing import (
10+
BaselineType,
11+
SliceIntType,
12+
TensorOrTupleOfTensorsGeneric,
13+
)
1014
from captum.attr._core.feature_ablation import FeatureAblation
1115
from captum.attr._utils.attribution import NeuronAttribution, PerturbationAttribution
1216
from captum.log import log_usage
@@ -31,8 +35,7 @@ class NeuronFeatureAblation(NeuronAttribution, PerturbationAttribution):
3135

3236
def __init__(
3337
self,
34-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
35-
forward_func: Callable,
38+
forward_func: Callable[..., Union[int, float, Tensor]],
3639
layer: Module,
3740
device_ids: Union[None, List[int]] = None,
3841
) -> None:
@@ -61,8 +64,11 @@ def __init__(
6164
def attribute(
6265
self,
6366
inputs: TensorOrTupleOfTensorsGeneric,
64-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
65-
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
67+
neuron_selector: Union[
68+
int,
69+
Tuple[Union[int, SliceIntType], ...],
70+
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
71+
],
6672
baselines: BaselineType = None,
6773
additional_forward_args: Optional[object] = None,
6874
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
@@ -250,8 +256,7 @@ def attribute(
250256
>>> feature_mask=feature_mask)
251257
"""
252258

253-
# pyre-fixme[3]: Return type must be annotated.
254-
def neuron_forward_func(*args: Any):
259+
def neuron_forward_func(*args: Any) -> Tensor:
255260
with torch.no_grad():
256261
layer_eval = _forward_layer_eval(
257262
self.forward_func,

0 commit comments

Comments
 (0)