4 changes: 4 additions & 0 deletions captum/_utils/common.py
@@ -86,6 +86,10 @@ def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
def _is_tuple(inputs: Tensor) -> Literal[False]: ...


@typing.overload
def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
return isinstance(inputs, tuple)

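Why the third overload is needed: callers holding a `TensorOrTupleOfTensorsGeneric` (a constrained `TypeVar`, not a plain `Union`) previously matched neither the `Literal[True]` nor the `Literal[False]` overload, so pyre flagged the call — the old code papered over this with a `pyre-fixme[6]`, removed from `lime.py` below. A minimal, self-contained sketch of the pattern; names here are illustrative, not Captum's API:

```python
import typing
from typing import Literal, Tuple, TypeVar, Union

import torch
from torch import Tensor

# Constrained TypeVar mirroring captum._utils.typing.TensorOrTupleOfTensorsGeneric.
TensorOrTuple = TypeVar("TensorOrTuple", Tensor, Tuple[Tensor, ...])


@typing.overload
def is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
@typing.overload
def is_tuple(inputs: Tensor) -> Literal[False]: ...
@typing.overload
def is_tuple(inputs: TensorOrTuple) -> bool: ...


def is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
    return isinstance(inputs, tuple)


def attribute_like(inputs: TensorOrTuple) -> TensorOrTuple:
    # Without the TypeVar overload, this call matches neither Literal
    # overload: from the checker's point of view `inputs` is neither
    # exactly Tensor nor exactly Tuple[Tensor, ...].
    print(is_tuple(inputs))
    return inputs


attribute_like(torch.zeros(2))     # prints False
attribute_like((torch.zeros(2),))  # prints True
```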
16 changes: 1 addition & 15 deletions captum/_utils/typing.py
@@ -2,25 +2,11 @@

# pyre-strict

from typing import (
List,
Optional,
overload,
Protocol,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union

from torch import Tensor
from torch.nn import Module

if TYPE_CHECKING:
from typing import Literal
else:
Literal = {True: bool, False: bool, (True, False): bool, "pt": str}

TensorOrTupleOfTensorsGeneric = TypeVar(
"TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
)
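The deleted `TYPE_CHECKING` block looks like a runtime shim from before the project could rely on Python 3.8: `typing.Literal` only exists at runtime from 3.8 onward, so the old code bound `Literal` to a dict whose subscription happened to yield `bool`, purely to keep annotations from raising. With `Literal` importable directly, the shim and the conditional import both go away. A short illustration (not Captum code) of what the shim did and why the direct import suffices:

```python
# Python >= 3.8: Literal is importable and subscriptable at runtime.
from typing import Literal


def set_strict(flag: Literal[True]) -> bool:
    return flag


# The removed shim emulated subscriptability only:
#     Literal = {True: bool, False: bool, (True, False): bool, "pt": str}
# so `Literal[True]` evaluated to `bool` at runtime -- enough to keep
# annotations from crashing, but meaningless to runtime introspection.
print(Literal[True])  # typing.Literal[True]
```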
25 changes: 10 additions & 15 deletions captum/attr/_core/kernel_shap.py
@@ -2,7 +2,7 @@

# pyre-strict

from typing import Any, Callable, Generator, Tuple, Union
from typing import Callable, cast, Generator, Tuple, Union

import torch
from captum._utils.models.linear_model import SkLearnLinearRegression
@@ -27,8 +27,7 @@ class KernelShap(Lime):
https://arxiv.org/abs/1705.07874
"""

# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def __init__(self, forward_func: Callable) -> None:
def __init__(self, forward_func: Callable[..., Tensor]) -> None:
r"""
Args:

@@ -50,8 +49,7 @@ def attribute( # type: ignore
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_samples: int = 25,
perturbations_per_eval: int = 1,
@@ -279,10 +277,7 @@ def attribute( # type: ignore
)
num_features_list = torch.arange(num_interp_features, dtype=torch.float)
denom = num_features_list * (num_interp_features - num_features_list)
# pyre-fixme[58]: `/` is not supported for operand types
# `int` and `torch._tensor.Tensor`.
probs = (num_interp_features - 1) / denom
# pyre-fixme[16]: `float` has no attribute `__setitem__`.
probs = torch.tensor((num_interp_features - 1)) / denom
probs[0] = 0.0
return self._attribute_kwargs(
inputs=inputs,
@@ -309,8 +304,7 @@ def kernel_shap_similarity_kernel(
_,
__,
interpretable_sample: Tensor,
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
**kwargs: object,
) -> Tensor:
assert (
"num_interp_features" in kwargs
@@ -332,8 +326,7 @@ def kernel_shap_perturb_generator(
def kernel_shap_perturb_generator(
self,
original_inp: Union[Tensor, Tuple[Tensor, ...]],
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
**kwargs: object,
) -> Generator[Tensor, None, None]:
r"""
Perturbations are sampled by the following process:
@@ -361,11 +354,13 @@ def kernel_shap_perturb_generator(
device = original_inp.device
else:
device = original_inp[0].device
num_features = kwargs["num_interp_features"]
num_features = cast(int, kwargs["num_interp_features"])
yield torch.ones(1, num_features, device=device, dtype=torch.long)
yield torch.zeros(1, num_features, device=device, dtype=torch.long)
while True:
num_selected_features = kwargs["num_select_distribution"].sample()
num_selected_features = cast(
Categorical, kwargs["num_select_distribution"]
).sample()
rand_vals = torch.randn(1, num_features)
threshold = torch.kthvalue(
rand_vals, num_features - num_selected_features
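Two typing patterns recur in this file. First, annotating `**kwargs` as `object` (rather than `Any`) forces every value read out of `kwargs` to be explicitly narrowed before use — hence the new `cast(int, ...)` and `cast(Categorical, ...)` calls, which are erased at runtime. Second, wrapping `num_interp_features - 1` in `torch.tensor(...)` makes the division Tensor/Tensor; behavior is unchanged, but it sidesteps pyre's complaint that `int.__truediv__` does not accept a Tensor. A self-contained sketch of the `cast` pattern, with illustrative names:

```python
from typing import cast

import torch
from torch.distributions import Categorical


def perturb(**kwargs: object) -> torch.Tensor:
    # Values pulled from `kwargs` are typed `object`; arithmetic and
    # method calls need an explicit cast, which costs nothing at runtime.
    num_features = cast(int, kwargs["num_interp_features"])
    dist = cast(Categorical, kwargs["num_select_distribution"])
    k = int(dist.sample().item())
    sample = torch.zeros(1, num_features, dtype=torch.long)
    sample[0, :k] = 1
    return sample


dist = Categorical(torch.ones(4) / 4)
print(perturb(num_interp_features=4, num_select_distribution=dist))
```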
100 changes: 43 additions & 57 deletions captum/attr/_core/lime.py
@@ -6,7 +6,7 @@
import typing
import warnings
from collections.abc import Iterator
from typing import Any, Callable, cast, List, Optional, Tuple, Union
from typing import Any, Callable, cast, List, Literal, Optional, Tuple, Union

import torch
from captum._utils.common import (
@@ -23,12 +23,7 @@
from captum._utils.models.linear_model import SkLearnLasso
from captum._utils.models.model import Model
from captum._utils.progress import progress
from captum._utils.typing import (
BaselineType,
Literal,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.batching import _batch_example_iterator
from captum.attr._utils.common import (
@@ -73,18 +68,18 @@ class LimeBase(PerturbationAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
interpretable_model: Model,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
similarity_func: Callable,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
perturb_func: Callable,
similarity_func: Callable[
...,
Union[float, Tensor],
],
perturb_func: Callable[..., object],
perturb_interpretable_space: bool,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
from_interp_rep_transform: Optional[Callable],
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
to_interp_rep_transform: Optional[Callable],
from_interp_rep_transform: Optional[
Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
],
to_interp_rep_transform: Optional[Callable[..., Tensor]],
) -> None:
r"""

@@ -249,13 +244,11 @@ def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
n_samples: int = 50,
perturbations_per_eval: int = 1,
show_progress: bool = False,
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
**kwargs: object,
) -> Tensor:
r"""
This method attributes the output of the model with given target index
@@ -551,7 +544,7 @@ def generate_perturbation() -> (
curr_sample, inputs, **kwargs
)

return interpretable_inp, curr_model_input
return interpretable_inp, curr_model_input # type: ignore

return generate_perturbation

@@ -568,8 +561,7 @@ def _evaluate_batch(
self,
curr_model_inputs: List[TensorOrTupleOfTensorsGeneric],
expanded_target: TargetType,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
expanded_additional_args: Any,
expanded_additional_args: object,
device: torch.device,
) -> Tensor:
model_out = _run_forward(
@@ -630,8 +622,7 @@ def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
def get_exp_kernel_similarity_function(
distance_mode: str = "cosine",
kernel_width: float = 1.0,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
) -> Callable:
) -> Callable[..., float]:
r"""
This method constructs an appropriate similarity function to compute
weights for perturbed sample in LIME. Distance between the original
@@ -680,8 +671,9 @@ def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs):
return default_exp_kernel


# pyre-fixme[2]: Parameter must be annotated.
def default_perturb_func(original_inp, **kwargs) -> Tensor:
def default_perturb_func(
original_inp: TensorOrTupleOfTensorsGeneric, **kwargs: object
) -> Tensor:
assert (
"num_interp_features" in kwargs
), "Must provide num_interp_features to use default interpretable sampling function"
@@ -690,25 +682,25 @@ def default_perturb_func(original_inp, **kwargs) -> Tensor:
else:
device = original_inp[0].device

probs = torch.ones(1, kwargs["num_interp_features"]) * 0.5
probs = torch.ones(1, cast(int, kwargs["num_interp_features"])) * 0.5
return torch.bernoulli(probs).to(device=device).long()


def construct_feature_mask(
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]],
formatted_inputs: Tuple[Tensor, ...],
) -> Tuple[Tuple[Tensor, ...], int]:
feature_mask_tuple: Tuple[Tensor, ...]
if feature_mask is None:
feature_mask, num_interp_features = _construct_default_feature_mask(
feature_mask_tuple, num_interp_features = _construct_default_feature_mask(
formatted_inputs
)
else:
feature_mask = _format_tensor_into_tuples(feature_mask)
feature_mask_tuple = _format_tensor_into_tuples(feature_mask)
min_interp_features = int(
min(
torch.min(single_mask).item()
# pyre-fixme[16]: `None` has no attribute `__iter__`.
for single_mask in feature_mask
for single_mask in feature_mask_tuple
if single_mask.numel()
)
)
@@ -718,14 +710,12 @@ def construct_feature_mask(
" start at 0.",
stacklevel=2,
)
feature_mask = tuple(
single_mask - min_interp_features for single_mask in feature_mask
feature_mask_tuple = tuple(
single_mask - min_interp_features for single_mask in feature_mask_tuple
)

# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
# `Optional[typing.Tuple[typing.Any, ...]]`.
num_interp_features = _get_max_feature_index(feature_mask) + 1
return feature_mask, num_interp_features
num_interp_features = _get_max_feature_index(feature_mask_tuple) + 1
return feature_mask_tuple, num_interp_features


class Lime(LimeBase):
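The `construct_feature_mask` rewrite uses a standard trick for strict type checkers: instead of reassigning the `feature_mask` parameter (declared `Union[None, Tensor, Tuple[Tensor, ...]]`) with values of a narrower type, it declares a fresh `feature_mask_tuple: Tuple[Tensor, ...]` up front, so the `pyre-fixme` suppressions for `None.__iter__` and the `Optional` argument become unnecessary. A minimal sketch of the idiom under hypothetical names:

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


def normalize_mask(mask: Optional[Tensor]) -> Tuple[Tensor, ...]:
    # Declare the narrowed variable once with its precise type;
    # reusing `mask` would leave it Optional[Tensor] for the whole body.
    mask_tuple: Tuple[Tensor, ...]
    if mask is None:
        mask_tuple = (torch.zeros(3, dtype=torch.long),)
    else:
        mask_tuple = (mask,)
    return mask_tuple


print(normalize_mask(None))
print(normalize_mask(torch.arange(3)))
```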
@@ -766,8 +756,7 @@ class Lime(LimeBase):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
interpretable_model: Optional[Model] = None,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
similarity_func: Optional[Callable] = None,
@@ -887,8 +876,7 @@ def attribute( # type: ignore
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_samples: int = 25,
perturbations_per_eval: int = 1,
@@ -1133,18 +1121,14 @@ def _attribute_kwargs( # type: ignore
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: object = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_samples: int = 25,
perturbations_per_eval: int = 1,
return_input_shape: bool = True,
show_progress: bool = False,
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
**kwargs: object,
) -> TensorOrTupleOfTensorsGeneric:
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `TensorOrTupleOfTensorsGeneric`.
is_inputs_tuple = _is_tuple(inputs)
formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
bsz = formatted_inputs[0].shape[0]
@@ -1263,33 +1247,35 @@ def _attribute_kwargs( # type: ignore
return coefs

@typing.overload
# pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept
# all possible arguments of overload defined on line `1201`.
def _convert_output_shape(
self,
formatted_inp: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
coefs: Tensor,
num_interp_features: int,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_inputs_tuple: Literal[True],
) -> Tuple[Tensor, ...]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept
# all possible arguments of overload defined on line `1211`.
def _convert_output_shape( # type: ignore
self,
formatted_inp: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
coefs: Tensor,
num_interp_features: int,
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_inputs_tuple: Literal[False],
) -> Tensor: ...

@typing.overload
def _convert_output_shape(
self,
formatted_inp: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
coefs: Tensor,
num_interp_features: int,
is_inputs_tuple: bool,
) -> Union[Tensor, Tuple[Tensor, ...]]: ...

def _convert_output_shape(
self,
formatted_inp: Tuple[Tensor, ...],
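The newly added third `_convert_output_shape` overload is what lets the `pyre-fixme` suppressions on the `Literal` overloads be dropped: with only the `Literal[True]` and `Literal[False]` overloads, a call site passing a plain `bool` — such as the result of `_is_tuple` — matched neither signature. A compact illustration of the three-overload shape, with hypothetical names:

```python
import typing
from typing import Literal, Tuple, Union

import torch
from torch import Tensor


@typing.overload
def pack(t: Tensor, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...
@typing.overload
def pack(t: Tensor, as_tuple: Literal[False]) -> Tensor: ...
@typing.overload
def pack(t: Tensor, as_tuple: bool) -> Union[Tensor, Tuple[Tensor, ...]]: ...


def pack(t: Tensor, as_tuple: bool) -> Union[Tensor, Tuple[Tensor, ...]]:
    # The bool overload is the one a non-literal flag resolves to.
    return (t,) if as_tuple else t


flag: bool = isinstance((), tuple)  # a plain bool, not a Literal
print(pack(torch.ones(2), flag))    # matches the third overload
```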