Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions captum/_utils/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from captum._utils.sample_gradient import SampleGradientWrapper
from captum._utils.typing import (
ModuleOrModuleList,
SliceIntType,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
Expand Down Expand Up @@ -775,8 +776,11 @@ def compute_layer_gradients_and_eval(

def construct_neuron_grad_fn(
layer: Module,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
neuron_selector: Union[
int,
Tuple[Union[int, SliceIntType], ...],
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
],
device_ids: Union[None, List[int]] = None,
attribute_to_neuron_input: bool = False,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
Expand Down
7 changes: 7 additions & 0 deletions captum/_utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@
TensorLikeList5D,
]

try:
# Subscripted slice syntax is not supported in previous Python versions,
# falling back to slice type.
SliceIntType = slice[int, int, int]
except TypeError:
# pyre-fixme[24]: Generic type `slice` expects 3 type parameters.
SliceIntType = slice # type: ignore

# Necessary for Python >=3.7 and <3.9!
if TYPE_CHECKING:
Expand Down
12 changes: 7 additions & 5 deletions captum/attr/_core/neuron/neuron_integrated_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, List, Optional, Tuple, Union

from captum._utils.gradient import construct_neuron_grad_fn
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.integrated_gradients import IntegratedGradients
from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
from captum.log import log_usage
Expand All @@ -27,8 +27,7 @@ class NeuronIntegratedGradients(NeuronAttribution, GradientAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
layer: Module,
device_ids: Union[None, List[int]] = None,
multiply_by_inputs: bool = True,
Expand Down Expand Up @@ -76,8 +75,11 @@ def __init__(
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
neuron_selector: Union[
int,
Tuple[Union[int, SliceIntType], ...],
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
],
baselines: Union[None, Tensor, Tuple[Tensor, ...]] = None,
additional_forward_args: Optional[object] = None,
n_steps: int = 50,
Expand Down
Loading