4 changes: 4 additions & 0 deletions captum/_utils/common.py
@@ -86,6 +86,10 @@ def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
def _is_tuple(inputs: Tensor) -> Literal[False]: ...


@typing.overload
def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
return isinstance(inputs, tuple)

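The new third overload is what lets call sites holding a `TensorOrTupleOfTensorsGeneric` value drop their pyre-fixme[6] suppressions. A minimal, self-contained sketch of how the three overloads resolve; the caller `attribute` below is hypothetical:

```python
# Sketch: with only the Tensor / Tuple overloads, an argument typed with the
# constrained TypeVar matches neither, so the checker flags the call.
from typing import Literal, Tuple, TypeVar, Union, overload

from torch import Tensor

TensorOrTupleOfTensorsGeneric = TypeVar(
    "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
)


@overload
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
@overload
def _is_tuple(inputs: Tensor) -> Literal[False]: ...
@overload
def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
    return isinstance(inputs, tuple)


def attribute(inputs: TensorOrTupleOfTensorsGeneric) -> bool:
    # Without the third overload this call needed a pyre-fixme[6] suppression.
    return _is_tuple(inputs)
```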
16 changes: 1 addition & 15 deletions captum/_utils/typing.py
@@ -2,25 +2,11 @@

# pyre-strict

from typing import (
List,
Optional,
overload,
Protocol,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union

from torch import Tensor
from torch.nn import Module

if TYPE_CHECKING:
from typing import Literal
else:
Literal = {True: bool, False: bool, (True, False): bool, "pt": str}

TensorOrTupleOfTensorsGeneric = TypeVar(
"TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
)
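With `Literal` imported directly from `typing`, the runtime dict shim disappears. A small sketch of the pattern this enables, assuming Python 3.8+ is the supported floor (an assumption, not stated in this diff); the `pick` function is purely illustrative:

```python
# Sketch: on Python >= 3.8 (assumed minimum here), Literal is importable
# unconditionally, so the old shim
#     if TYPE_CHECKING: from typing import Literal
#     else: Literal = {True: bool, False: bool, ...}
# is no longer needed for runtime imports to succeed.
from typing import Literal, Optional, overload


@overload
def pick(flag: Literal[True]) -> int: ...
@overload
def pick(flag: Literal[False]) -> None: ...


def pick(flag: bool) -> Optional[int]:
    # The runtime body only ever sees a bool; Literal matters to the checker,
    # which narrows the return type per call site.
    return 1 if flag else None
```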
27 changes: 9 additions & 18 deletions captum/attr/_core/input_x_gradient.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

# pyre-strict
from typing import Any, Callable
from typing import Callable

from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
from captum._utils.gradient import (
@@ -11,6 +11,7 @@
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import GradientAttribution
from captum.log import log_usage
from torch import Tensor


class InputXGradient(GradientAttribution):
@@ -20,8 +21,7 @@ class InputXGradient(GradientAttribution):
https://arxiv.org/abs/1605.01713
"""

# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def __init__(self, forward_func: Callable) -> None:
def __init__(self, forward_func: Callable[..., Tensor]) -> None:
r"""
Args:

@@ -35,8 +35,7 @@ def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
) -> TensorOrTupleOfTensorsGeneric:
r"""
Args:
@@ -113,28 +112,20 @@ def attribute(
"""
# Keeps track whether original input is a tuple or not before
# converting it into a tuple.
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)

# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
# `Tuple[Tensor, ...]`.
inputs = _format_tensor_into_tuples(inputs)
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
gradient_mask = apply_gradient_requirements(inputs)
inputs_tuple = _format_tensor_into_tuples(inputs)
gradient_mask = apply_gradient_requirements(inputs_tuple)

gradients = self.gradient_func(
self.forward_func, inputs, target, additional_forward_args
self.forward_func, inputs_tuple, target, additional_forward_args
)

attributions = tuple(
input * gradient for input, gradient in zip(inputs, gradients)
input * gradient for input, gradient in zip(inputs_tuple, gradients)
)

# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
undo_gradient_requirements(inputs, gradient_mask)
undo_gradient_requirements(inputs_tuple, gradient_mask)
# pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
# `Tuple[Tensor, ...]`.
return _format_output(is_inputs_tuple, attributions)
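Binding the formatted tuple to a new name (`inputs_tuple`) instead of reassigning `inputs` is what removes the pyre-fixme[6]/[9] suppressions: `inputs` keeps its `TensorOrTupleOfTensorsGeneric` type for `_format_output`, while `inputs_tuple` is a plain `Tuple[Tensor, ...]` for the gradient helpers. A compressed sketch of the pattern; the helper and the `scale` function are stand-ins, not Captum's real implementations:

```python
from typing import Tuple, TypeVar, Union

from torch import Tensor

TensorOrTupleOfTensorsGeneric = TypeVar(
    "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
)


def _format_tensor_into_tuples(
    inputs: Union[Tensor, Tuple[Tensor, ...]]
) -> Tuple[Tensor, ...]:
    # Stand-in for the real helper: always hand back a tuple of tensors.
    return inputs if isinstance(inputs, tuple) else (inputs,)


def scale(inputs: TensorOrTupleOfTensorsGeneric) -> TensorOrTupleOfTensorsGeneric:
    is_inputs_tuple = isinstance(inputs, tuple)
    # New name: `inputs` keeps its TypeVar type, `inputs_tuple` is a concrete
    # Tuple[Tensor, ...], so downstream helpers type-check without fixmes.
    inputs_tuple = _format_tensor_into_tuples(inputs)
    results = tuple(x * 2 for x in inputs_tuple)
    # The final un-formatting step still needs a suppression, as in the diff.
    return results if is_inputs_tuple else results[0]  # type: ignore
```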
45 changes: 14 additions & 31 deletions captum/attr/_core/integrated_gradients.py
@@ -2,7 +2,7 @@

# pyre-strict
import typing
from typing import Any, Callable, List, Tuple, Union
from typing import Any, Callable, List, Literal, Tuple, Union

import torch
from captum._utils.common import (
@@ -12,12 +12,7 @@
_format_output,
_is_tuple,
)
from captum._utils.typing import (
BaselineType,
Literal,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.approximation_methods import approximation_parameters
from captum.attr._utils.attribution import GradientAttribution
from captum.attr._utils.batching import _batch_attribution
@@ -49,8 +44,7 @@ class IntegratedGradients(GradientAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
multiply_by_inputs: bool = True,
) -> None:
r"""
@@ -80,21 +74,16 @@ def __init__(
# and when return_convergence_delta is True, the return type is
# a tuple with both attributions and deltas.
@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `95`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
n_steps: int = 50,
method: str = "gausslegendre",
internal_batch_size: Union[None, int] = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...

@@ -111,9 +100,6 @@ def attribute(
n_steps: int = 50,
method: str = "gausslegendre",
internal_batch_size: Union[None, int] = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
) -> TensorOrTupleOfTensorsGeneric: ...

@@ -275,37 +261,35 @@ def attribute(  # type: ignore
"""
# Keeps track whether original input is a tuple or not before
# converting it into a tuple.
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)

# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
# `Tuple[Tensor, ...]`.
inputs, baselines = _format_input_baseline(inputs, baselines)
formatted_inputs, formatted_baselines = _format_input_baseline(
inputs, baselines
)

# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `TensorOrTupleOfTensorsGeneric`.
_validate_input(inputs, baselines, n_steps, method)
_validate_input(formatted_inputs, formatted_baselines, n_steps, method)

if internal_batch_size is not None:
num_examples = inputs[0].shape[0]
num_examples = formatted_inputs[0].shape[0]
attributions = _batch_attribution(
self,
num_examples,
internal_batch_size,
n_steps,
inputs=inputs,
baselines=baselines,
inputs=formatted_inputs,
baselines=formatted_baselines,
target=target,
additional_forward_args=additional_forward_args,
method=method,
)
else:
attributions = self._attribute(
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
# got `TensorOrTupleOfTensorsGeneric`.
inputs=inputs,
baselines=baselines,
inputs=formatted_inputs,
baselines=formatted_baselines,
target=target,
additional_forward_args=additional_forward_args,
n_steps=n_steps,
@@ -344,8 +328,7 @@ def _attribute(
inputs: Tuple[Tensor, ...],
baselines: Tuple[Union[Tensor, int, float], ...],
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
n_steps: int = 50,
method: str = "gausslegendre",
step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
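With `Literal` usable directly, the `attribute` overloads express the usual "flag selects the return type" pattern without pyre-fixme[24]/[31]. A trimmed sketch of the same pattern with simplified signatures:

```python
# Trimmed sketch of the overload pattern used by attribute(); the class and
# computation below are simplified stand-ins for the real signature.
import typing
from typing import Literal, Tuple, Union

from torch import Tensor


class Attr:
    @typing.overload
    def attribute(
        self, inputs: Tensor, *, return_convergence_delta: Literal[True]
    ) -> Tuple[Tensor, Tensor]: ...

    @typing.overload
    def attribute(
        self, inputs: Tensor, return_convergence_delta: Literal[False] = False
    ) -> Tensor: ...

    def attribute(
        self, inputs: Tensor, return_convergence_delta: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        attributions = inputs * 0.5  # placeholder computation
        if return_convergence_delta:
            delta = attributions.sum()
            return attributions, delta
        return attributions


# Callers get precise types per call site:
#   attr.attribute(x)                                 -> Tensor
#   attr.attribute(x, return_convergence_delta=True)  -> Tuple[Tensor, Tensor]
```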
25 changes: 10 additions & 15 deletions captum/attr/_core/kernel_shap.py
@@ -2,7 +2,7 @@

# pyre-strict

from typing import Any, Callable, Generator, Tuple, Union
from typing import Callable, cast, Generator, Tuple, Union

import torch
from captum._utils.models.linear_model import SkLearnLinearRegression
@@ -27,8 +27,7 @@ class KernelShap(Lime):
https://arxiv.org/abs/1705.07874
"""

# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def __init__(self, forward_func: Callable) -> None:
def __init__(self, forward_func: Callable[..., Tensor]) -> None:
r"""
Args:

@@ -50,8 +49,7 @@ def attribute(  # type: ignore
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_samples: int = 25,
perturbations_per_eval: int = 1,
@@ -279,10 +277,7 @@ def attribute(  # type: ignore
)
num_features_list = torch.arange(num_interp_features, dtype=torch.float)
denom = num_features_list * (num_interp_features - num_features_list)
# pyre-fixme[58]: `/` is not supported for operand types
# `int` and `torch._tensor.Tensor`.
probs = (num_interp_features - 1) / denom
# pyre-fixme[16]: `float` has no attribute `__setitem__`.
probs = torch.tensor((num_interp_features - 1)) / denom
probs[0] = 0.0
return self._attribute_kwargs(
inputs=inputs,
@@ -309,8 +304,7 @@ def kernel_shap_similarity_kernel(
_,
__,
interpretable_sample: Tensor,
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
**kwargs: object,
) -> Tensor:
assert (
"num_interp_features" in kwargs
@@ -332,8 +326,7 @@ def kernel_shap_perturb_generator(
def kernel_shap_perturb_generator(
self,
original_inp: Union[Tensor, Tuple[Tensor, ...]],
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
**kwargs: object,
) -> Generator[Tensor, None, None]:
r"""
Perturbations are sampled by the following process:
@@ -361,11 +354,13 @@ def kernel_shap_perturb_generator(
device = original_inp.device
else:
device = original_inp[0].device
num_features = kwargs["num_interp_features"]
num_features = cast(int, kwargs["num_interp_features"])
yield torch.ones(1, num_features, device=device, dtype=torch.long)
yield torch.zeros(1, num_features, device=device, dtype=torch.long)
while True:
num_selected_features = kwargs["num_select_distribution"].sample()
num_selected_features = cast(
Categorical, kwargs["num_select_distribution"]
).sample()
rand_vals = torch.randn(1, num_features)
threshold = torch.kthvalue(
rand_vals, num_features - num_selected_features
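Typing `**kwargs` as `object` keeps the signature honest without a pyre-fixme[2], and `typing.cast` recovers the concrete type at each point of use with no runtime cost. A sketch of the same idea; the `perturb` function below is illustrative, and only the kwarg names mirror the real generator:

```python
from typing import cast

import torch
from torch import Tensor
from torch.distributions.categorical import Categorical


def perturb(**kwargs: object) -> Tensor:
    # Every value arrives typed as `object`; cast() tells the checker which
    # concrete type this call site relies on and is a no-op at runtime.
    num_features = cast(int, kwargs["num_interp_features"])
    distribution = cast(Categorical, kwargs["num_select_distribution"])
    num_selected = int(distribution.sample())
    mask = torch.zeros(1, num_features, dtype=torch.long)
    mask[0, :num_selected] = 1
    return mask


# Example usage:
#   dist = Categorical(torch.tensor([0.5, 0.5]))
#   perturb(num_interp_features=4, num_select_distribution=dist)
```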