74 changes: 5 additions & 69 deletions captum/attr/_core/feature_permutation.py
@@ -55,9 +55,9 @@ class FeaturePermutation(FeatureAblation):
of examples to compute attributions and cannot be performed on a single example.

By default, each scalar value within
-each input tensor is taken as a feature and shuffled independently, *unless*
-attribute() is called with enable_cross_tensor_attribution=True. Passing
-a feature mask, allows grouping features to be shuffled together.
+each input tensor is taken as a feature and shuffled independently. Passing
+a feature mask allows grouping features to be shuffled together (including
+features defined across different input tensors).
Each input scalar in the group will be given the same attribution value
equal to the change in target as a result of shuffling the entire feature
group.
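The cross-tensor grouping this docstring now describes can be seen in a short usage sketch (the toy forward function, shapes, and mask values below are illustrative assumptions, not taken from this PR):

```python
import torch
from captum.attr import FeaturePermutation

# Toy forward function over two input tensors; returns one scalar per example.
def forward_func(x1, x2):
    return x1.sum(dim=1) + x2.sum(dim=1)

x1 = torch.randn(8, 3)  # batch of 8 examples, 3 scalar features
x2 = torch.randn(8, 2)  # same batch, 2 more scalar features

# Feature ids are shared across tensors: id 0 groups x1[:, 0] with x2[:, 0],
# so those columns are shuffled together and receive one attribution value.
# The first dimension of each mask is 1: one grouping for the whole batch.
mask1 = torch.tensor([[0, 1, 1]])
mask2 = torch.tensor([[0, 2]])

fp = FeaturePermutation(forward_func)
attr1, attr2 = fp.attribute((x1, x2), feature_mask=(mask1, mask2))
```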
@@ -92,12 +92,6 @@ def __init__(
"""
FeatureAblation.__init__(self, forward_func=forward_func)
self.perm_func = perm_func
-# Minimum number of elements needed in each input tensor, when
-# `enable_cross_tensor_attribution` is False, otherwise the
-# attribution for the tensor will be skipped. Set to 1 to throw if any
-# input tensors only have one example
-self._min_examples_per_batch = 2
-# Similar to above, when `enable_cross_tensor_attribution` is True.
# Considering the case when we permute multiple input tensors at once
# through `feature_mask`, we disregard the feature group if the 0th
# dim of *any* input tensor in the group is less than
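Permuting over the batch dimension is only meaningful with at least two examples, which is what the minimum-examples guard above enforces. A minimal sketch of batch-wise feature shuffling in the spirit of the default permutation function (the name and the `randperm` strategy are assumptions for illustration; captum's actual default may differ):

```python
import torch
from torch import Tensor

def permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
    # Shuffle the masked feature across the batch (dim 0). With a single
    # example there is no other row to draw a permuted value from.
    n = x.size(0)
    assert n > 1, "cannot permute features of a batch with one example"
    perm = torch.randperm(n)
    # Keep unmasked positions; take masked positions from a shuffled batch.
    return x * ~feature_mask + x[perm] * feature_mask
```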
@@ -115,7 +109,6 @@ def attribute( # type: ignore
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
-enable_cross_tensor_attribution: bool = True,
**kwargs: Any,
) -> TensorOrTupleOfTensorsGeneric:
r"""
@@ -187,18 +180,12 @@ def attribute( # type: ignore
input tensor. Each tensor should contain integers in
the range 0 to num_features - 1, and indices
corresponding to the same feature should have the
-same value. Note that features within each input
-tensor are ablated independently (not across
-tensors), unless enable_cross_tensor_attribution is
-True.
-
+same value.
The first dimension of each mask must be 1, as we require
to have the same group of features for each input sample.

If None, then a feature mask is constructed which assigns
-each scalar within a tensor as a separate feature, which
-is permuted independently, unless
-enable_cross_tensor_attribution is True.
+each scalar within a tensor as a separate feature.
Default: None
perturbations_per_eval (int, optional): Allows permutations
of multiple features to be processed simultaneously
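As a concrete reading of the `None` default described above, the constructed mask is equivalent to numbering every scalar position in a tensor once and broadcasting over the batch (the input shape is an illustrative assumption):

```python
import torch

x = torch.randn(8, 3)
# One feature id per scalar position, with a leading batch dim of 1:
default_mask = torch.arange(x[0].numel()).reshape(1, *x.shape[1:])
# default_mask == tensor([[0, 1, 2]]): each column permuted independently.
```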
@@ -217,10 +204,6 @@ def attribute( # type: ignore
(e.g. time estimation). Otherwise, it will fallback to
a simple output of progress.
Default: False
-enable_cross_tensor_attribution (bool, optional): If True, then
-features can be grouped across input tensors depending on
-the values in the feature mask.
-Default: False
**kwargs (Any, optional): Any additional arguments used by child
classes of :class:`.FeatureAblation` (such as
:class:`.Occlusion`) to construct ablations. These
@@ -292,7 +275,6 @@ def attribute( # type: ignore
feature_mask=feature_mask,
perturbations_per_eval=perturbations_per_eval,
show_progress=show_progress,
-enable_cross_tensor_attribution=enable_cross_tensor_attribution,
**kwargs,
)

@@ -304,7 +286,6 @@ def attribute_future(
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
-enable_cross_tensor_attribution: bool = True,
**kwargs: Any,
) -> Future[TensorOrTupleOfTensorsGeneric]:
"""
@@ -321,54 +302,9 @@ def attribute_future(
feature_mask=feature_mask,
perturbations_per_eval=perturbations_per_eval,
show_progress=show_progress,
-enable_cross_tensor_attribution=enable_cross_tensor_attribution,
**kwargs,
)

-def _construct_ablated_input(
-self,
-expanded_input: Tensor,
-input_mask: Union[None, Tensor, Tuple[Tensor, ...]],
-baseline: Union[None, float, Tensor],
-start_feature: int,
-end_feature: int,
-**kwargs: Any,
-) -> Tuple[Tensor, Tensor]:
-r"""
-This function permutes the features of `expanded_input` with a given
-feature mask and feature range. Permutation occurs via calling
-`self.perm_func` across each batch within `expanded_input`. As with
-`FeatureAblation._construct_ablated_input`:
-- `expanded_input.shape = (num_features, num_examples, ...)`
-- `num_features = end_feature - start_feature` (i.e. start and end is a
-half-closed interval)
-- `input_mask` is a tensor of the same shape as one input, which
-describes the locations of each feature via their "index"
-
-Since `baselines` is set to None for `FeatureAblation.attribute, this
-will be the zero tensor, however, it is not used.
-"""
-assert (
-input_mask is not None
-and not isinstance(input_mask, tuple)
-and input_mask.shape[0] == 1
-), (
-"input_mask.shape[0] != 1: pass in one mask in order to permute"
-"the same features for each input"
-)
-current_mask = torch.stack(
-[input_mask == j for j in range(start_feature, end_feature)], dim=0
-).bool()
-current_mask = current_mask.to(expanded_input.device)
-
-output = torch.stack(
-[
-self.perm_func(x, mask.squeeze(0))
-for x, mask in zip(expanded_input, current_mask)
-]
-)
-return output, current_mask
-
def _construct_ablated_input_across_tensors(
self,
inputs: Tuple[Tensor, ...],