7 changes: 4 additions & 3 deletions captum/_utils/typing.py
@@ -21,8 +21,9 @@
 TensorOrTupleOfTensorsGeneric = TypeVar(
     "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
 )
-# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
-TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool)
+TupleOrTensorOrBoolGeneric = TypeVar(
+    "TupleOrTensorOrBoolGeneric", Tuple[Tensor, ...], Tensor, bool
+)
 ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
 TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
 BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]]
@@ -46,7 +47,7 @@
     # falling back to slice type.
     SliceIntType = slice[int, int, int]
 except TypeError:
-    # pyre-fixme[24]: Generic type `slice` expects 3 type parameters.
+    # pyre-ignore[24]: Generic type `slice` expects 3 type parameters.
     SliceIntType = slice  # type: ignore

 # Necessary for Python >=3.7 and <3.9!
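For context, parameterizing the constraint as Tuple[Tensor, ...] keeps the TypeVar's value restriction meaningful to the type checker instead of silencing it. A minimal sketch, not part of this PR (the passthrough function is illustrative only):

```python
from typing import Tuple, TypeVar

from torch import Tensor

# Same constrained TypeVar as in captum/_utils/typing.py.
TupleOrTensorOrBoolGeneric = TypeVar(
    "TupleOrTensorOrBoolGeneric", Tuple[Tensor, ...], Tensor, bool
)


def passthrough(value: TupleOrTensorOrBoolGeneric) -> TupleOrTensorOrBoolGeneric:
    # The checker binds the TypeVar per call site: a Tensor argument yields a
    # Tensor return type, a Tuple[Tensor, ...] argument yields a tuple, etc.
    return value
```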
72 changes: 40 additions & 32 deletions captum/attr/_core/layer/layer_integrated_gradients.py
@@ -33,6 +33,7 @@
 )
 from captum.log import log_usage
 from torch import Tensor
+from torch.nn import Module
 from torch.nn.parallel.scatter_gather import scatter


@@ -58,8 +59,7 @@ class LayerIntegratedGradients(LayerAttribution, GradientAttribution):

     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: ModuleOrModuleList,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
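The bare Callable is tightened to Callable[..., Tensor]: any argument list, but a Tensor return. A hedged sketch of a conforming call site (the toy model and shapes are assumptions, not from this PR):

```python
import torch
from torch import Tensor, nn

from captum.attr import LayerIntegratedGradients

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def forward_func(x: Tensor) -> Tensor:
    # Any callable returning a Tensor satisfies Callable[..., Tensor].
    return model(x)


lig = LayerIntegratedGradients(forward_func, layer=model[0])
attributions = lig.attribute(torch.rand(3, 4), target=0)
```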
@@ -128,8 +128,7 @@ def _make_gradient_func(
     ) -> Callable[..., Tuple[Tensor, ...]]:

         def _gradient_func(
-            # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-            forward_fn: Callable,
+            forward_fn: Callable[..., Tensor],
             inputs: Union[Tensor, Tuple[Tensor, ...]],
             target_ind: TargetType = None,
             additional_forward_args: Optional[object] = None,
@@ -146,28 +145,21 @@ def _gradient_func(
                     target_gpus=self.device_ids,
                 )

-            scattered_inputs_dict = {
+            scattered_inputs_dict: Dict[
+                torch.device, Union[Tensor, Tuple[Tensor, ...]]
+            ] = {
                 scattered_input[0].device: scattered_input
                 for scattered_input in scattered_inputs
             }

             with torch.autograd.set_grad_enabled(True):

-                # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
-                # annotated.
-                # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
-                # annotated.
-                # pyre-fixme[3]: Return type must be annotated.
                 def layer_forward_hook(
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    module,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_inputs,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_outputs=None,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    layer_idx=0,
-                ):
+                    module: Module,
+                    hook_inputs: Union[Tensor, Tuple[Tensor, ...]],
+                    hook_outputs: Union[None, Tensor, Tuple[Tensor, ...]] = None,
+                    layer_idx: int = 0,
+                ) -> Union[Tensor, Tuple[Tensor, ...]]:
                     device = _extract_device(module, hook_inputs, hook_outputs)
                     is_layer_tuple = (
                         isinstance(hook_outputs, tuple)
@@ -177,11 +169,14 @@ def layer_forward_hook(
                     )

                     if is_layer_tuple:
-                        return scattered_inputs_dict[device][
-                            num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
-                                layer_idx + 1
-                            ]
-                        ]
+                        return cast(
+                            Union[Tensor, Tuple[Tensor, ...]],
+                            scattered_inputs_dict[device][
+                                num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
+                                    layer_idx + 1
+                                ]
+                            ],
+                        )

                     return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]

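layer_forward_hook relies on a standard PyTorch behavior: a forward hook that returns a value replaces the module's output, which is how the scattered, interpolated activations are injected into the forward pass. A self-contained illustration of that mechanism (the toy layer and shapes are assumptions):

```python
import torch
from torch import Tensor, nn

layer = nn.Linear(4, 4)
replacement = torch.ones(2, 4)


def swap_output(module: nn.Module, inputs: object, output: Tensor) -> Tensor:
    # Returning a non-None value from a forward hook overrides the layer's
    # output for all downstream computation.
    return replacement


handle = layer.register_forward_hook(swap_output)
out = layer(torch.rand(2, 4))
assert torch.equal(out, replacement)
handle.remove()
```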
@@ -502,11 +497,22 @@ def attribute(
             additional_forward_args
         )

-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def flatten_tuple(tup):
+        def flatten_tuple(tup: List[Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
             return tuple(
-                sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in tup), [])
+                cast(
+                    List[Tensor],
+                    sum(
+                        (
+                            (
+                                list(x)
+                                if isinstance(x, (tuple, list))
+                                else cast(List[Tensor], [x])
+                            )
+                            for x in tup
+                        ),
+                        [],
+                    ),
+                )
             )

         if self.device_ids is None:
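What flatten_tuple computes is unchanged: a list of per-layer tuples collapses into one flat tuple of tensors; the casts only spell out the intermediate list type for pyre. A small stand-alone illustration (hypothetical shapes):

```python
import torch

layer_outputs = [
    (torch.zeros(1), torch.zeros(2)),  # a layer with two outputs
    (torch.zeros(3),),                 # a layer with a single output
]

# Same flattening idea as flatten_tuple, without the pyre-oriented casts.
flat = tuple(
    sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in layer_outputs), [])
)
assert len(flat) == 3
```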
@@ -520,16 +526,18 @@ def flatten_tuple(tup):
             additional_forward_args=additional_forward_args,
             attribute_to_layer_input=attribute_to_layer_input,
         )

+        input_layer_list: List[Tuple[Tensor, ...]]
         # if we have one output
         if not isinstance(self.layer, list):
-            inputs_layer = (inputs_layer,)
+            input_layer_list = [cast(Tuple[Tensor, ...], inputs_layer)]
+        else:
+            input_layer_list = inputs_layer

-        num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in inputs_layer]
+        num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in input_layer_list]
         num_outputs_cumsum = torch.cumsum(
             torch.IntTensor([0] + num_outputs), dim=0  # type: ignore
         )
-        inputs_layer = flatten_tuple(inputs_layer)
+        inputs_layer = flatten_tuple(input_layer_list)

         baselines_layer = _forward_layer_eval(
             self.forward_func,
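The num_outputs_cumsum bookkeeping that input_layer_list feeds is unchanged: cumulative offsets let each layer's slice be recovered from the flattened tuple. A sketch with hypothetical counts:

```python
import torch

num_outputs = [2, 1, 3]  # hypothetical number of outputs per attributed layer
num_outputs_cumsum = torch.cumsum(torch.IntTensor([0] + num_outputs), dim=0)

flat = tuple(range(sum(num_outputs)))  # stands in for the flattened tensors
for layer_idx in range(len(num_outputs)):
    start = int(num_outputs_cumsum[layer_idx])
    stop = int(num_outputs_cumsum[layer_idx + 1])
    assert len(flat[start:stop]) == num_outputs[layer_idx]
```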
6 changes: 4 additions & 2 deletions tests/helpers/evaluate_linear_model.py
@@ -6,11 +6,13 @@
 from typing import cast, Dict

 import torch
+
+from captum._utils.models.linear_model.model import LinearModel
 from torch import Tensor
 from torch.utils.data import DataLoader


-# pyre-fixme[2]: Parameter must be annotated.
-def evaluate(test_data, classifier) -> Dict[str, Tensor]:
+def evaluate(test_data: DataLoader, classifier: LinearModel) -> Dict[str, Tensor]:
     classifier.eval()

     l1_loss = 0.0
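The typed signature documents that evaluate iterates a DataLoader of (inputs, labels) batches and expects a captum LinearModel. A hedged sketch of the data it consumes; the commented call assumes a model fitted elsewhere:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

features = torch.rand(32, 5)
labels = features @ torch.rand(5, 1)
test_loader = DataLoader(TensorDataset(features, labels), batch_size=8)

# With a fitted captum LinearModel (construction omitted here):
# metrics = evaluate(test_loader, trained_linear_model)  # Dict[str, Tensor]
```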
19 changes: 8 additions & 11 deletions tests/metrics/test_sensitivity.py
@@ -3,7 +3,7 @@
 # pyre-strict

 import typing
-from typing import Callable, cast, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
@@ -28,19 +28,15 @@


 @typing.overload
-# pyre-fixme[43]: The implementation of `_perturb_func` does not accept all possible
-# arguments of overload defined on line `32`.
 def _perturb_func(inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: ...


 @typing.overload
-# pyre-fixme[43]: The implementation of `_perturb_func` does not accept all possible
-# arguments of overload defined on line `28`.
 def _perturb_func(inputs: Tensor) -> Tensor: ...


 def _perturb_func(
-    inputs: TensorOrTupleOfTensorsGeneric,
+    inputs: Union[Tensor, Tuple[Tensor, ...]],
 ) -> Union[Tensor, Tuple[Tensor, ...]]:
     def perturb_ratio(input: Tensor) -> Tensor:
         return (
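The overloads are kept so callers still see a precise return type per argument type, while the implementation's annotation is widened to a plain Union that both overloads can satisfy. A minimal sketch of the pattern (the _echo function is illustrative and aimed at a static checker):

```python
import typing
from typing import Tuple, Union

import torch
from torch import Tensor


@typing.overload
def _echo(inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: ...
@typing.overload
def _echo(inputs: Tensor) -> Tensor: ...


def _echo(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> Union[Tensor, Tuple[Tensor, ...]]:
    return inputs


single = _echo(torch.rand(2))                 # checker infers Tensor
pair = _echo((torch.rand(2), torch.rand(2)))  # checker infers Tuple[Tensor, ...]
```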
@@ -55,7 +51,7 @@ def perturb_ratio(input: Tensor) -> Tensor:
         input1 = inputs[0]
         input2 = inputs[1]
     else:
-        input1 = cast(Tensor, inputs)
+        input1 = inputs

     perturbed_input1 = input1 + perturb_ratio(input1)

@@ -283,12 +279,13 @@ def test_classification_sensitivity_tpl_target_w_baseline(self) -> None:

     def sensitivity_max_assert(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        expl_func: Callable,
+        expl_func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
         inputs: TensorOrTupleOfTensorsGeneric,
         expected_sensitivity: Tensor,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        perturb_func: Callable = _perturb_func,
+        perturb_func: Union[
+            Callable[[Tensor], Tensor],
+            Callable[[Tuple[Tensor, ...]], Tuple[Tensor, ...]],
+        ] = _perturb_func,
         n_perturb_samples: int = 5,
         max_examples_per_batch: Optional[int] = None,
         baselines: Optional[BaselineType] = None,
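Replacing the bare Callable with a Union of concrete signatures means any custom perturb_func passed to sensitivity_max_assert must take and return either a Tensor or a tuple of Tensors. A conforming example (the function name and noise scale are assumptions):

```python
from typing import Tuple

import torch
from torch import Tensor


def add_uniform_noise(inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]:
    # Matches Callable[[Tuple[Tensor, ...]], Tuple[Tensor, ...]], the second
    # member of the accepted Union.
    return tuple(x + 0.01 * torch.rand_like(x) for x in inputs)
```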