4 changes: 4 additions & 0 deletions captum/_utils/common.py
@@ -86,6 +86,10 @@ def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
 def _is_tuple(inputs: Tensor) -> Literal[False]: ...
 
 
+@typing.overload
+def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...
+
+
 def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
     return isinstance(inputs, tuple)
 
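Note: the new third overload exists for call sites that pass the `TensorOrTupleOfTensorsGeneric` TypeVar itself, which matches neither concrete overload. A minimal self-contained sketch of the pattern (illustrative standalone module, not Captum's actual layout):

```python
import typing
from typing import Literal, Tuple, TypeVar, Union

from torch import Tensor

TensorOrTupleOfTensorsGeneric = TypeVar(
    "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
)


@typing.overload
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
@typing.overload
def _is_tuple(inputs: Tensor) -> Literal[False]: ...
@typing.overload
def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
    # Only this implementation runs; the overloads are type-checker-only.
    return isinstance(inputs, tuple)


def attribute(inputs: TensorOrTupleOfTensorsGeneric) -> bool:
    # `inputs` is typed as the TypeVar, not as Tensor or Tuple[Tensor, ...],
    # so without the TypeVar overload the checker reported an error here.
    return _is_tuple(inputs)
```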
16 changes: 1 addition & 15 deletions captum/_utils/typing.py
@@ -2,25 +2,11 @@
 
 # pyre-strict
 
-from typing import (
-    List,
-    Optional,
-    overload,
-    Protocol,
-    Tuple,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union
 
 from torch import Tensor
 from torch.nn import Module
 
-if TYPE_CHECKING:
-    from typing import Literal
-else:
-    Literal = {True: bool, False: bool, (True, False): bool, "pt": str}
-
 TensorOrTupleOfTensorsGeneric = TypeVar(
     "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
 )
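Context for the deletion: the removed block was a runtime stand-in for `typing.Literal` from when Captum supported Pythons that lacked it; `typing.Literal` is available on Python 3.8+, so the direct import suffices. A sketch of what the old shim did at runtime (standalone reconstruction, not current Captum code):

```python
TYPE_CHECKING = False  # what the `typing` module reports at runtime

if TYPE_CHECKING:
    from typing import Literal
else:
    # Subscripting the dict mimics Literal's syntax: `Literal[True]`
    # evaluates to `bool`, keeping annotations evaluable at runtime.
    Literal = {True: bool, False: bool, (True, False): bool, "pt": str}

print(Literal[True])         # <class 'bool'>
print(Literal[True, False])  # tuple subscription -> <class 'bool'>
print(Literal["pt"])         # <class 'str'>
```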
13 changes: 4 additions & 9 deletions captum/attr/_core/occlusion.py
@@ -37,8 +37,7 @@ class Occlusion(FeatureAblation):
     /tensorflow/methods.py#L401
     """
 
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    def __init__(self, forward_func: Callable) -> None:
+    def __init__(self, forward_func: Callable[..., Tensor]) -> None:
         r"""
         Args:
 
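The ellipsis form keeps the parameter list unconstrained while pinning the return type, which resolves pyre-fixme[24] without committing to a specific signature. A quick illustration (hypothetical forward function):

```python
from typing import Callable

import torch
from torch import Tensor


def forward(x: Tensor, scale: float = 1.0) -> Tensor:
    # Any signature is acceptable; only the Tensor return is constrained.
    return (x * scale).sum(dim=1)


forward_func: Callable[..., Tensor] = forward
print(forward_func(torch.ones(2, 3)))  # tensor([3., 3.])
```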
@@ -58,8 +57,7 @@ def attribute(  # type: ignore
         ] = None,
         baselines: BaselineType = None,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
     ) -> TensorOrTupleOfTensorsGeneric:
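On the `Any` to `object` swap: `object` is the safe top type, since a value annotated `object` cannot be used without narrowing, whereas `Any` silences checking entirely. `additional_forward_args` is only passed through opaquely, so `object` fits. A small sketch of the difference:

```python
from typing import Any


def with_any(args: Any) -> None:
    args.undefined_method()  # type-checks, but may crash at runtime


def with_object(args: object) -> None:
    # args.undefined_method()  # rejected by the type checker
    if isinstance(args, tuple):  # narrowing makes tuple operations legal
        print(len(args))


with_object((1, 2, 3))  # prints 3
```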
@@ -377,9 +375,7 @@ def _occlusion_mask(
         padded_tensor = torch.nn.functional.pad(
             sliding_window_tsr, tuple(pad_values)  # type: ignore
         )
-        # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` and
-        #  `Size`.
-        return padded_tensor.reshape((1,) + padded_tensor.shape)
+        return padded_tensor.reshape((1,) + tuple(padded_tensor.shape))
 
     def _get_feature_range_and_mask(
         self, input: Tensor, input_mask: Optional[Tensor], **kwargs: Any
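`torch.Size` subclasses `tuple`, so `(1,) + padded_tensor.shape` already worked at runtime; pyre-fixme[58] was a stub-level complaint that `Tuple[int].__add__` has no overload accepting `Size`. Converting through `tuple()` is behavior-preserving:

```python
import torch

padded = torch.ones(3, 4)
# Runtime-equivalent expressions; only the second satisfies pyre's stubs.
a = padded.reshape((1,) + padded.shape)         # flagged by pyre-fixme[58]
b = padded.reshape((1,) + tuple(padded.shape))  # type-clean
assert a.shape == b.shape == torch.Size([1, 3, 4])
```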
@@ -389,8 +385,7 @@ def _get_feature_range_and_mask(
 
     def _get_feature_counts(
         self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        inputs,
+        inputs: TensorOrTupleOfTensorsGeneric,
         feature_mask: Tuple[Tensor, ...],
         **kwargs: Any,
     ) -> Tuple[int, ...]:
25 changes: 8 additions & 17 deletions captum/attr/_core/saliency.py
@@ -2,7 +2,7 @@
 
 # pyre-strict
 
-from typing import Any, Callable
+from typing import Callable
 
 import torch
 from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
@@ -13,6 +13,7 @@
 from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.attribution import GradientAttribution
 from captum.log import log_usage
+from torch import Tensor
 
 
 class Saliency(GradientAttribution):
@@ -25,8 +26,7 @@ class Saliency(GradientAttribution):
     https://arxiv.org/abs/1312.6034
     """
 
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    def __init__(self, forward_func: Callable) -> None:
+    def __init__(self, forward_func: Callable[..., Tensor]) -> None:
         r"""
         Args:
 
@@ -41,8 +41,7 @@ def attribute(
         inputs: TensorOrTupleOfTensorsGeneric,
         target: TargetType = None,
         abs: bool = True,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
     ) -> TensorOrTupleOfTensorsGeneric:
         r"""
         Args:
@@ -124,29 +123,21 @@ def attribute(
         """
         # Keeps track whether original input is a tuple or not before
         # converting it into a tuple.
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
         is_inputs_tuple = _is_tuple(inputs)
 
-        # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
-        #  `Tuple[Tensor, ...]`.
-        inputs = _format_tensor_into_tuples(inputs)
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
-        gradient_mask = apply_gradient_requirements(inputs)
+        inputs_tuple = _format_tensor_into_tuples(inputs)
+        gradient_mask = apply_gradient_requirements(inputs_tuple)
 
         # No need to format additional_forward_args here.
         # They are being formated in the `_run_forward` function in `common.py`
         gradients = self.gradient_func(
-            self.forward_func, inputs, target, additional_forward_args
+            self.forward_func, inputs_tuple, target, additional_forward_args
         )
         if abs:
             attributions = tuple(torch.abs(gradient) for gradient in gradients)
         else:
             attributions = gradients
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
-        undo_gradient_requirements(inputs, gradient_mask)
+        undo_gradient_requirements(inputs_tuple, gradient_mask)
         # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
         #  `Tuple[Tensor, ...]`.
         return _format_output(is_inputs_tuple, attributions)
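The pyre-fixme[9]/[6] batch above came from rebinding `inputs`, whose declared type is the TypeVar `TensorOrTupleOfTensorsGeneric`, to the plain `Tuple[Tensor, ...]` that `_format_tensor_into_tuples` returns. Giving the formatted value its own name keeps both types precise. The shape of the fix, roughly (simplified stand-in helper, not Captum's real one):

```python
from typing import Tuple, TypeVar, Union

import torch
from torch import Tensor

TensorOrTupleOfTensorsGeneric = TypeVar(
    "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
)


def _format_tensor_into_tuples(
    inputs: Union[Tensor, Tuple[Tensor, ...]]
) -> Tuple[Tensor, ...]:
    # Simplified stand-in for the real helper in captum/_utils/common.py.
    return inputs if isinstance(inputs, tuple) else (inputs,)


def attribute(inputs: TensorOrTupleOfTensorsGeneric) -> None:
    # Before: `inputs = _format_tensor_into_tuples(inputs)` rebound the
    # TypeVar-typed name to Tuple[Tensor, ...], triggering pyre-fixme[9]
    # here and pyre-fixme[6] at every later use. A fresh name needs no casts.
    inputs_tuple: Tuple[Tensor, ...] = _format_tensor_into_tuples(inputs)
    print(len(inputs_tuple))


attribute(torch.ones(2, 2))  # prints 1
```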
96 changes: 29 additions & 67 deletions captum/attr/_core/shapley_value.py
@@ -5,7 +5,7 @@
 import itertools
 import math
 import warnings
-from typing import Any, Callable, cast, Iterable, Sequence, Tuple, Union
+from typing import Callable, cast, Iterable, Sequence, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -56,9 +56,7 @@ def _shape_feature_mask(
            f"input shape: {inp.shape}, feature mask shape {mask.shape}"
        )
        if mask.dim() < inp.dim():
-            # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int,
-            #  ...]` and `Size`.
-            mask = mask.reshape((1,) * (inp.dim() - mask.dim()) + mask.shape)
+            mask = mask.reshape((1,) * (inp.dim() - mask.dim()) + tuple(mask.shape))
 
        mask_list.append(mask)
 
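The reshape left-pads the mask with singleton dimensions until it has the input's rank, so it broadcasts across the leading (e.g. batch and channel) dimensions. A standalone illustration of the same expression (toy shapes, not Captum code):

```python
import torch

inp = torch.zeros(8, 3, 32, 32)               # batch of images
mask = torch.zeros(32, 32, dtype=torch.long)  # one feature id per pixel
# Left-pad with singleton dims: (32, 32) -> (1, 1, 32, 32)
mask = mask.reshape((1,) * (inp.dim() - mask.dim()) + tuple(mask.shape))
print(mask.shape)          # torch.Size([1, 1, 32, 32])
print((inp + mask).shape)  # broadcasts to torch.Size([8, 3, 32, 32])
```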
@@ -89,8 +87,7 @@ class ShapleyValueSampling(PerturbationAttribution):
     https://pdfs.semanticscholar.org/7715/bb1070691455d1fcfc6346ff458dbca77b2c.pdf
     """
 
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    def __init__(self, forward_func: Callable) -> None:
+    def __init__(self, forward_func: Callable[..., Union[int, float, Tensor]]) -> None:
         r"""
         Args:
 
@@ -111,8 +108,7 @@ def attribute(
         inputs: TensorOrTupleOfTensorsGeneric,
         baselines: BaselineType = None,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
         n_samples: int = 25,
         perturbations_per_eval: int = 1,
@@ -301,45 +297,25 @@ def attribute(
         """
         # Keeps track whether original input is a tuple or not before
         # converting it into a tuple.
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
         is_inputs_tuple = _is_tuple(inputs)
-        # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
-        #  `Tuple[Tensor, ...]`.
-        inputs, baselines = _format_input_baseline(inputs, baselines)
+        inputs_tuple, baselines = _format_input_baseline(inputs, baselines)
         additional_forward_args = _format_additional_forward_args(
             additional_forward_args
         )
-        # pyre-fixme[9]: feature_mask has type
-        #  `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
-        #  typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`.
-        # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
-        feature_mask = _format_feature_mask(feature_mask, inputs)
-        # pyre-fixme[9]: feature_mask has type
-        #  `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
-        #  typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`.
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        #  `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
-        #  typing.Tuple[Tensor, ...]]]]`.
-        # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
-        feature_mask = _shape_feature_mask(feature_mask, inputs)
+        formatted_feature_mask = _format_feature_mask(feature_mask, inputs_tuple)
+        reshaped_feature_mask = _shape_feature_mask(
+            formatted_feature_mask, inputs_tuple
+        )
 
         assert (
             isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
         ), "Ablations per evaluation must be at least 1."
 
         with torch.no_grad():
-            # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-            #  `TensorOrTupleOfTensorsGeneric`.
-            baselines = _tensorize_baseline(inputs, baselines)
-            num_examples = inputs[0].shape[0]
+            baselines = _tensorize_baseline(inputs_tuple, baselines)
+            num_examples = inputs_tuple[0].shape[0]
 
-            # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-            #  `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
-            #  typing.Tuple[Tensor, ...]]]]`.
-            total_features = _get_max_feature_index(feature_mask) + 1
+            total_features = _get_max_feature_index(reshaped_feature_mask) + 1
 
             if show_progress:
                 attr_progress = progress(
@@ -362,7 +338,7 @@
                 initial_eval,
                 num_examples,
                 perturbations_per_eval,
-                feature_mask,
+                reshaped_feature_mask,
                 allow_multi_outputs=True,
             )
 
@@ -372,11 +348,11 @@
             # attr shape (*output_shape, *input_feature_shape)
             total_attrib = [
                 torch.zeros(
-                    output_shape + input.shape[1:],
+                    tuple(output_shape) + tuple(input.shape[1:]),
                     dtype=torch.float,
-                    device=inputs[0].device,
+                    device=inputs_tuple[0].device,
                 )
-                for input in inputs
+                for input in inputs_tuple
             ]
 
             iter_count = 0
@@ -393,17 +369,11 @@
                     current_target,
                     current_masks,
                 ) in self._perturbation_generator(
-                    # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]`
-                    #  but got `TensorOrTupleOfTensorsGeneric`.
-                    inputs,
+                    inputs_tuple,
                     additional_forward_args,
                     target,
                     baselines,
-                    # pyre-fixme[6]: For 5th argument expected
-                    #  `TensorOrTupleOfTensorsGeneric` but got
-                    #  `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
-                    #  typing.Tuple[Tensor, ...]]]]`.
-                    feature_mask,
+                    reshaped_feature_mask,
                     feature_permutation,
                     perturbations_per_eval,
                 ):
@@ -445,10 +415,8 @@ def attribute(
                         # have the same dim as the mask tensor.
                         formatted_eval_diff = eval_diff.reshape(
                             (-1,)
-                            # pyre-fixme[58]: `+` is not supported for operand types
-                            #  `Tuple[int]` and `Size`.
-                            + output_shape
-                            + (len(inputs[j].shape) - 1) * (1,)
+                            + tuple(output_shape)
+                            + (len(inputs_tuple[j].shape) - 1) * (1,)
                         )
 
                         # mask in shape (n_perturb, *mask_shape_broadcastable_to_input)
@@ -460,11 +428,9 @@
                         #     )
                         cur_mask = current_masks[j]
                         cur_mask = cur_mask.reshape(
-                            cur_mask.shape[:2]
+                            tuple(cur_mask.shape[:2])
                             + (len(output_shape) - 1) * (1,)
-                            # pyre-fixme[58]: `+` is not supported for operand types
-                            #  `Tuple[int, ...]` and `Size`.
-                            + cur_mask.shape[2:]
+                            + tuple(cur_mask.shape[2:])
                         )
 
                         # aggregate n_perturb
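The two reshapes above exist so that `formatted_eval_diff` and `cur_mask` broadcast against each other before the perturbation dimension is summed away. A toy illustration with one scalar output per example (shapes and names simplified, not Captum's actual tensors):

```python
import torch

n_perturb, batch, feat = 4, 8, 16
output_shape = (batch,)  # per-example scalar outputs
eval_diff = torch.randn(n_perturb * batch)

# (n_perturb, *output_shape, 1): broadcastable over the feature dim
formatted_eval_diff = eval_diff.reshape((-1,) + tuple(output_shape) + (1,))

# (n_perturb, 1, feat): broadcastable over the batch dim
cur_mask = torch.randint(0, 2, (n_perturb, 1, feat))

# Product broadcasts to (n_perturb, batch, feat); summing dim 0
# aggregates the contribution of all perturbations in the batch.
attrib_update = (formatted_eval_diff * cur_mask).sum(dim=0)
print(attrib_update.shape)  # torch.Size([8, 16])
```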
@@ -495,18 +461,16 @@ def attribute_future(self) -> Callable:
             "attribute_future is not implemented for ShapleyValueSampling"
         )
 
-    # pyre-fixme[3]: Return annotation cannot contain `Any`.
     def _perturbation_generator(
         self,
         inputs: Tuple[Tensor, ...],
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_args: Any,
+        additional_args: object,
         target: TargetType,
         baselines: Tuple[Tensor, ...],
         input_masks: TensorOrTupleOfTensorsGeneric,
         feature_permutation: Sequence[int],
         perturbations_per_eval: int,
-    ) -> Iterable[Tuple[Tuple[Tensor, ...], Any, TargetType, Tuple[Tensor, ...]]]:
+    ) -> Iterable[Tuple[Tuple[Tensor, ...], object, TargetType, Tuple[Tensor, ...]]]:
         """
         This method is a generator which yields each perturbation to be evaluated
         including inputs, additional_forward_args, targets, and mask.
Expand Down Expand Up @@ -578,9 +542,9 @@ def _perturbation_generator(
combined_masks,
)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval):
def _get_n_evaluations(
self, total_features: int, n_samples: int, perturbations_per_eval: int
) -> int:
"""return the total number of forward evaluations needed"""
return math.ceil(total_features / perturbations_per_eval) * n_samples

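The now-annotated helper just counts forward passes: each of the `n_samples` permutations ablates every feature once, batched `perturbations_per_eval` at a time. A worked example:

```python
import math

total_features, n_samples, perturbations_per_eval = 10, 25, 4
# ceil(10 / 4) = 3 batched forward calls per permutation, 25 permutations:
print(math.ceil(total_features / perturbations_per_eval) * n_samples)  # 75
```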
@@ -642,8 +606,7 @@ class ShapleyValues(ShapleyValueSampling):
     evaluations, and we plan to add this approach in the future.
     """
 
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    def __init__(self, forward_func: Callable) -> None:
+    def __init__(self, forward_func: Callable[..., Union[int, float, Tensor]]) -> None:
         r"""
         Args:
 
@@ -664,8 +627,7 @@ def attribute(
         inputs: TensorOrTupleOfTensorsGeneric,
         baselines: BaselineType = None,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
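Unlike `Saliency` and `Occlusion`, the Shapley classes annotate the forward callable as `Callable[..., Union[int, float, Tensor]]`, since perturbation-based attribution can also consume plain scalar outputs. A sketch of both accepted shapes (hypothetical model functions):

```python
from typing import Callable, List, Union

import torch
from torch import Tensor


def tensor_forward(x: Tensor) -> Tensor:
    return x.sum(dim=1)  # per-example tensor output


def scalar_forward(x: Tensor) -> float:
    return float(x.sum())  # a single Python scalar is also acceptable


forward_funcs: List[Callable[..., Union[int, float, Tensor]]] = [
    tensor_forward,
    scalar_forward,
]
for fn in forward_funcs:
    print(fn(torch.ones(2, 3)))
```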