Skip to content

Commit 7fb54d8

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix neuron_integrated_gradients pyre fixme issues (#1457)
Summary: Pull Request resolved: #1457 Differential Revision: D67523072
1 parent 9a7ef2e commit 7fb54d8

File tree

3 files changed

+20
-7
lines changed

3 files changed

+20
-7
lines changed

captum/_utils/gradient.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from captum._utils.sample_gradient import SampleGradientWrapper
2929
from captum._utils.typing import (
3030
ModuleOrModuleList,
31+
SliceIntType,
3132
TargetType,
3233
TensorOrTupleOfTensorsGeneric,
3334
)
@@ -775,8 +776,11 @@ def compute_layer_gradients_and_eval(
775776

776777
def construct_neuron_grad_fn(
777778
layer: Module,
778-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
779-
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
779+
neuron_selector: Union[
780+
int,
781+
Tuple[Union[int, SliceIntType], ...],
782+
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
783+
],
780784
device_ids: Union[None, List[int]] = None,
781785
attribute_to_neuron_input: bool = False,
782786
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.

captum/_utils/typing.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@
4141
TensorLikeList5D,
4242
]
4343

44+
try:
45+
# Subscripted slice syntax is not supported in previous Python versions,
46+
# falling back to slice type.
47+
SliceIntType = slice[int, int, int]
48+
except TypeError:
49+
# pyre-fixme[24]: Generic type `slice` expects 3 type parameters.
50+
SliceIntType = slice # type: ignore
4451

4552
# Necessary for Python >=3.7 and <3.9!
4653
if TYPE_CHECKING:

captum/attr/_core/neuron/neuron_integrated_gradients.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Callable, List, Optional, Tuple, Union
55

66
from captum._utils.gradient import construct_neuron_grad_fn
7-
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
7+
from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric
88
from captum.attr._core.integrated_gradients import IntegratedGradients
99
from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
1010
from captum.log import log_usage
@@ -27,8 +27,7 @@ class NeuronIntegratedGradients(NeuronAttribution, GradientAttribution):
2727

2828
def __init__(
2929
self,
30-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
31-
forward_func: Callable,
30+
forward_func: Callable[..., Tensor],
3231
layer: Module,
3332
device_ids: Union[None, List[int]] = None,
3433
multiply_by_inputs: bool = True,
@@ -76,8 +75,11 @@ def __init__(
7675
def attribute(
7776
self,
7877
inputs: TensorOrTupleOfTensorsGeneric,
79-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
80-
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
78+
neuron_selector: Union[
79+
int,
80+
Tuple[Union[int, SliceIntType], ...],
81+
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
82+
],
8183
baselines: Union[None, Tensor, Tuple[Tensor, ...]] = None,
8284
additional_forward_args: Optional[object] = None,
8385
n_steps: int = 50,

0 commit comments

Comments
 (0)