4 changes: 4 additions & 0 deletions captum/_utils/common.py
@@ -86,6 +86,10 @@ def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
 def _is_tuple(inputs: Tensor) -> Literal[False]: ...
 
 
+@typing.overload
+def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...
+
+
 def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
     return isinstance(inputs, tuple)
 
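Note on the new overload: inside a function that is generic over TensorOrTupleOfTensorsGeneric, the argument is typed as the TypeVar itself, so neither the Tensor nor the Tuple[Tensor, ...] overload applied and call sites needed the pyre-fixme[6] suppressions removed in the files below. The third overload accepts the TypeVar and returns plain bool. An illustrative sketch of the call pattern that now type-checks (simplified stand-in names, not the Captum code):

    from typing import Tuple, TypeVar

    from torch import Tensor

    T = TypeVar("T", Tensor, Tuple[Tensor, ...])  # stand-in for TensorOrTupleOfTensorsGeneric

    def uses_is_tuple(inputs: T) -> bool:
        # With only the Literal[True]/Literal[False] overloads, a call like
        # _is_tuple(inputs) has no matching overload here, because `inputs`
        # is typed as T rather than as Tensor or Tuple[Tensor, ...]. The new
        # `-> bool` overload covers exactly this case.
        return isinstance(inputs, tuple)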
16 changes: 1 addition & 15 deletions captum/_utils/typing.py
@@ -2,25 +2,11 @@
 
 # pyre-strict
 
-from typing import (
-    List,
-    Optional,
-    overload,
-    Protocol,
-    Tuple,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union
 
 from torch import Tensor
 from torch.nn import Module
 
-if TYPE_CHECKING:
-    from typing import Literal
-else:
-    Literal = {True: bool, False: bool, (True, False): bool, "pt": str}
-
 TensorOrTupleOfTensorsGeneric = TypeVar(
     "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
 )
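Note on the deleted block: this was apparently a compatibility shim from before the codebase could assume Python 3.8, where typing.Literal first appeared; under TYPE_CHECKING the real Literal was imported, while at runtime the dict made subscripts such as Literal[True] evaluate to bool so annotations would not raise. With Literal importable directly, the shim (and the pyre errors it caused at use sites) can go. An illustrative before/after, not from the diff:

    # Old runtime shim: subscripting the dict yields a plain type object.
    Literal = {True: bool, False: bool, (True, False): bool, "pt": str}
    assert Literal[True] is bool

    # New form: the real typing.Literal (Python 3.8+) used directly.
    from typing import Literal as L

    def f(return_delta: L[True]) -> None: ...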
27 changes: 8 additions & 19 deletions captum/attr/_core/guided_backprop_deconvnet.py
@@ -2,7 +2,7 @@
 
 # pyre-strict
 import warnings
-from typing import Any, Callable, List, Tuple, Union
+from typing import Callable, List, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -45,8 +45,7 @@ def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
     ) -> TensorOrTupleOfTensorsGeneric:
         r"""
         Computes attribution by overriding relu gradients. Based on constructor
@@ -58,16 +57,10 @@
 
         # Keeps track whether original input is a tuple or not before
         # converting it into a tuple.
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-        # `TensorOrTupleOfTensorsGeneric`.
         is_inputs_tuple = _is_tuple(inputs)
 
-        # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
-        # `Tuple[Tensor, ...]`.
-        inputs = _format_tensor_into_tuples(inputs)
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        # `TensorOrTupleOfTensorsGeneric`.
-        gradient_mask = apply_gradient_requirements(inputs)
+        inputs_tuple = _format_tensor_into_tuples(inputs)
+        gradient_mask = apply_gradient_requirements(inputs_tuple)
 
         # set hooks for overriding ReLU gradients
         warnings.warn(
@@ -79,14 +72,12 @@
             self.model.apply(self._register_hooks)
 
             gradients = self.gradient_func(
-                self.forward_func, inputs, target, additional_forward_args
+                self.forward_func, inputs_tuple, target, additional_forward_args
             )
         finally:
             self._remove_hooks()
 
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        # `TensorOrTupleOfTensorsGeneric`.
-        undo_gradient_requirements(inputs, gradient_mask)
+        undo_gradient_requirements(inputs_tuple, gradient_mask)
         # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
         # `Tuple[Tensor, ...]`.
         return _format_output(is_inputs_tuple, gradients)
@@ -155,8 +146,7 @@ def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
     ) -> TensorOrTupleOfTensorsGeneric:
         r"""
         Args:
@@ -265,8 +255,7 @@ def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
     ) -> TensorOrTupleOfTensorsGeneric:
         r"""
         Args:
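Note on the pattern above: _format_tensor_into_tuples returns Tuple[Tensor, ...], so rebinding the parameter `inputs` (typed with the TypeVar TensorOrTupleOfTensorsGeneric) to its result is what forced the pyre-fixme[9] and the downstream pyre-fixme[6] suppressions; binding the normalized value to a fresh local sidesteps the retyping entirely. Likewise `object` replaces `Any` for additional_forward_args because the argument is only forwarded, never inspected, and pyre-strict rejects `Any` annotations outright. A minimal illustrative sketch of the local-variable pattern (simplified names, not the Captum implementation):

    from typing import Tuple, TypeVar

    from torch import Tensor

    T = TypeVar("T", Tensor, Tuple[Tensor, ...])

    def normalize(inputs: T) -> Tuple[Tensor, ...]:
        # A new local keeps the parameter's TypeVar type intact instead of
        # rebinding `inputs` to Tuple[Tensor, ...], which a strict checker
        # flags as an inconsistent reassignment.
        return inputs if isinstance(inputs, tuple) else (inputs,)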
21 changes: 7 additions & 14 deletions captum/attr/_core/guided_grad_cam.py
@@ -2,7 +2,7 @@
 
 # pyre-strict
 import warnings
-from typing import Any, List, Union
+from typing import List, Union
 
 import torch
 from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
@@ -72,8 +72,7 @@ def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         interpolate_mode: str = "nearest",
         attribute_to_layer_input: bool = False,
     ) -> TensorOrTupleOfTensorsGeneric:
@@ -181,15 +180,11 @@
             >>> # attribution size matches input size, Nx3x32x32
             >>> attribution = guided_gc.attribute(input, 3)
         """
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-        # `TensorOrTupleOfTensorsGeneric`.
         is_inputs_tuple = _is_tuple(inputs)
-        # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
-        # `Tuple[Tensor, ...]`.
-        inputs = _format_tensor_into_tuples(inputs)
+        inputs_tuple = _format_tensor_into_tuples(inputs)
         grad_cam_attr = self.grad_cam.attribute.__wrapped__(
             self.grad_cam,  # self
-            inputs=inputs,
+            inputs=inputs_tuple,
             target=target,
             additional_forward_args=additional_forward_args,
             attribute_to_layer_input=attribute_to_layer_input,
@@ -204,20 +199,18 @@
 
         guided_backprop_attr = self.guided_backprop.attribute.__wrapped__(
             self.guided_backprop,  # self
-            inputs=inputs,
+            inputs=inputs_tuple,
             target=target,
             additional_forward_args=additional_forward_args,
         )
         output_attr: List[Tensor] = []
-        for i in range(len(inputs)):
+        for i in range(len(inputs_tuple)):
             try:
                 output_attr.append(
                     guided_backprop_attr[i]
                     * LayerAttribution.interpolate(
                         grad_cam_attr,
-                        # pyre-fixme[6]: For 2nd argument expected `Union[int,
-                        # typing.Tuple[int, ...]]` but got `Size`.
-                        inputs[i].shape[2:],
+                        tuple(inputs_tuple[i].shape[2:]),
                         interpolate_mode=interpolate_mode,
                     )
                 )
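Note on the interpolate call: the removed pyre-fixme[6] complained that inputs[i].shape[2:] is a torch.Size where Union[int, Tuple[int, ...]] is expected; wrapping it in tuple(...) produces a plain Tuple[int, ...] and satisfies the annotation without a suppression. At runtime torch.Size is already a tuple subclass, so behavior is unchanged. A quick illustrative check:

    import torch

    x = torch.zeros(2, 3, 32, 32)
    spatial = x.shape[2:]              # torch.Size([32, 32]), a tuple subclass
    assert tuple(spatial) == (32, 32)  # plain Tuple[int, ...] for the checker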
27 changes: 9 additions & 18 deletions captum/attr/_core/input_x_gradient.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable
+from typing import Callable
 
 from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
 from captum._utils.gradient import (
@@ -11,6 +11,7 @@
 from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.attribution import GradientAttribution
 from captum.log import log_usage
+from torch import Tensor
 
 
 class InputXGradient(GradientAttribution):
@@ -20,8 +21,7 @@ class InputXGradient(GradientAttribution):
     https://arxiv.org/abs/1605.01713
     """
 
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    def __init__(self, forward_func: Callable) -> None:
+    def __init__(self, forward_func: Callable[..., Tensor]) -> None:
         r"""
         Args:
 
@@ -35,8 +35,7 @@ def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
     ) -> TensorOrTupleOfTensorsGeneric:
         r"""
         Args:
@@ -113,28 +112,20 @@ def attribute(
         """
         # Keeps track whether original input is a tuple or not before
         # converting it into a tuple.
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-        # `TensorOrTupleOfTensorsGeneric`.
        is_inputs_tuple = _is_tuple(inputs)
 
-        # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
-        # `Tuple[Tensor, ...]`.
-        inputs = _format_tensor_into_tuples(inputs)
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        # `TensorOrTupleOfTensorsGeneric`.
-        gradient_mask = apply_gradient_requirements(inputs)
+        inputs_tuple = _format_tensor_into_tuples(inputs)
+        gradient_mask = apply_gradient_requirements(inputs_tuple)
 
         gradients = self.gradient_func(
-            self.forward_func, inputs, target, additional_forward_args
+            self.forward_func, inputs_tuple, target, additional_forward_args
         )
 
         attributions = tuple(
-            input * gradient for input, gradient in zip(inputs, gradients)
+            input * gradient for input, gradient in zip(inputs_tuple, gradients)
        )
 
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        # `TensorOrTupleOfTensorsGeneric`.
-        undo_gradient_requirements(inputs, gradient_mask)
+        undo_gradient_requirements(inputs_tuple, gradient_mask)
         # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
         # `Tuple[Tensor, ...]`.
         return _format_output(is_inputs_tuple, attributions)
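Note on Callable[..., Tensor]: the removed pyre-fixme[24] flagged bare Callable, which leaves both the parameter list and the return type unspecified; Callable[..., Tensor] keeps arbitrary call signatures (the `...`) while pinning the return type to Tensor, the model output that attribution differentiates. A small sketch of the idea, assuming only that the forward function returns a Tensor:

    from typing import Callable

    import torch
    from torch import Tensor

    def evaluate(forward_func: Callable[..., Tensor], *args: object) -> Tensor:
        # `...` accepts any argument list; only the Tensor return is checked.
        return forward_func(*args)

    out = evaluate(lambda x: x * 2, torch.ones(3))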
45 changes: 14 additions & 31 deletions captum/attr/_core/integrated_gradients.py
@@ -2,7 +2,7 @@
 
 # pyre-strict
 import typing
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, List, Literal, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -12,12 +12,7 @@
     _format_output,
     _is_tuple,
 )
-from captum._utils.typing import (
-    BaselineType,
-    Literal,
-    TargetType,
-    TensorOrTupleOfTensorsGeneric,
-)
+from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.approximation_methods import approximation_parameters
 from captum.attr._utils.attribution import GradientAttribution
 from captum.attr._utils.batching import _batch_attribution
@@ -49,8 +44,7 @@ class IntegratedGradients(GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         multiply_by_inputs: bool = True,
     ) -> None:
         r"""
@@ -80,21 +74,16 @@ def __init__(
     # and when return_convergence_delta is True, the return type is
     # a tuple with both attributions and deltas.
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `95`.
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         baselines: BaselineType = None,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         n_steps: int = 50,
         method: str = "gausslegendre",
         internal_batch_size: Union[None, int] = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
     ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
 
@@ -111,9 +100,6 @@ def attribute(
         n_steps: int = 50,
         method: str = "gausslegendre",
         internal_batch_size: Union[None, int] = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
     ) -> TensorOrTupleOfTensorsGeneric: ...
 
@@ -275,37 +261,35 @@ def attribute(  # type: ignore
         """
         # Keeps track whether original input is a tuple or not before
         # converting it into a tuple.
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-        # `TensorOrTupleOfTensorsGeneric`.
         is_inputs_tuple = _is_tuple(inputs)
 
-        # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
-        # `Tuple[Tensor, ...]`.
-        inputs, baselines = _format_input_baseline(inputs, baselines)
+        formatted_inputs, formatted_baselines = _format_input_baseline(
+            inputs, baselines
+        )
 
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        # `TensorOrTupleOfTensorsGeneric`.
-        _validate_input(inputs, baselines, n_steps, method)
+        _validate_input(formatted_inputs, formatted_baselines, n_steps, method)
 
         if internal_batch_size is not None:
-            num_examples = inputs[0].shape[0]
+            num_examples = formatted_inputs[0].shape[0]
             attributions = _batch_attribution(
                 self,
                 num_examples,
                 internal_batch_size,
                 n_steps,
-                inputs=inputs,
-                baselines=baselines,
+                inputs=formatted_inputs,
+                baselines=formatted_baselines,
                 target=target,
                 additional_forward_args=additional_forward_args,
                 method=method,
            )
        else:
            attributions = self._attribute(
-                # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
-                # got `TensorOrTupleOfTensorsGeneric`.
-                inputs=inputs,
-                baselines=baselines,
+                inputs=formatted_inputs,
+                baselines=formatted_baselines,
                target=target,
                additional_forward_args=additional_forward_args,
                n_steps=n_steps,
@@ -344,8 +328,7 @@ def _attribute(
         inputs: Tuple[Tensor, ...],
         baselines: Tuple[Union[Tensor, int, float], ...],
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         n_steps: int = 50,
         method: str = "gausslegendre",
         step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
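Note on the overload cleanup: with the real typing.Literal imported, the return_convergence_delta: Literal[True] / Literal[False] pair is a valid type again, so the checker can pick the overload (and hence the return type) from the literal value of the flag; the old dict-based shim in captum/_utils/typing.py is what triggered fixmes [24], [31], [9], and [43] here. A self-contained sketch of the same pattern with simplified signatures (not the actual Captum API):

    from typing import Literal, Tuple, Union, overload

    import torch
    from torch import Tensor

    @overload
    def attribute(
        x: Tensor, *, return_convergence_delta: Literal[True]
    ) -> Tuple[Tensor, Tensor]: ...

    @overload
    def attribute(
        x: Tensor, return_convergence_delta: Literal[False] = False
    ) -> Tensor: ...

    def attribute(
        x: Tensor, return_convergence_delta: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        attr = 2 * x  # stand-in for the integrated-gradients computation
        return (attr, attr.sum()) if return_convergence_delta else attr

    attr_only = attribute(torch.ones(3))                                   # Tensor
    attr, delta = attribute(torch.ones(3), return_convergence_delta=True)  # Tuple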