From b42ec61a79f324eeae869e71802c9580d2be105f Mon Sep 17 00:00:00 2001
From: Sarah Tran
Date: Thu, 20 Feb 2025 16:22:10 -0800
Subject: [PATCH 1/3] Support feature grouping across input tensors for FeatureAblation (#1497)

Summary:
Basic support; doesn't currently support multiple perturbations per eval

Reviewed By: cyrjano, vivekmig

Differential Revision: D69531512
---
 captum/attr/_core/feature_ablation.py | 351 +++++++++++++++++++++-----
 setup.py                              |   2 +-
 tests/attr/test_feature_ablation.py   | 151 +++++++----
 3 files changed, 395 insertions(+), 109 deletions(-)

diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index e5e60bb465..5375dbb638 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -3,7 +3,7 @@
 # pyre-strict

 import math
-from typing import Any, Callable, cast, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Generator, List, Optional, Tuple, TypeVar, Union

 import torch
 from captum._utils.common import (
@@ -16,7 +16,7 @@
     _run_forward,
 )
 from captum._utils.exceptions import FeatureAblationFutureError
-from captum._utils.progress import progress
+from captum._utils.progress import progress, SimpleProgress
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.attribution import PerturbationAttribution
 from captum.attr._utils.common import _format_input_baseline
@@ -24,6 +24,13 @@
 from torch import dtype, Tensor
 from torch.futures import collect_all, Future

+try:
+    from tqdm.auto import tqdm
+except ImportError:
+    tqdm = None
+
+IterableType = TypeVar("IterableType")
+

 class FeatureAblation(PerturbationAttribution):
     r"""
@@ -79,6 +86,7 @@ def attribute(
         feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
+        enable_cross_tensor_attribution: bool = False,
         **kwargs: Any,
     ) -> TensorOrTupleOfTensorsGeneric:
         r"""
@@ -171,7 +179,8 @@ def attribute(
                         - 1, and indices corresponding to the same feature should
                         have the same value.
                         Note that features within each input tensor are ablated
-                        independently (not across tensors).
+                        independently (not across tensors), unless
+                        enable_cross_tensor_attribution is True.
                         If the forward function returns a single scalar per batch,
                         we enforce that the first dimension of each mask must be 1,
                         since attributions are returned batch-wise rather than per
@@ -179,7 +188,7 @@ def attribute(
                         same features (indices) in each input example.
                         If None, then a feature mask is constructed which assigns
                         each scalar within a tensor as a separate feature, which
-                        is ablated independently.
+                        is ablated independently by default.
                         Default: None
             perturbations_per_eval (int, optional): Allows ablation of multiple
                         features to be processed simultaneously in one call to
@@ -202,6 +211,10 @@ def attribute(
                         (e.g. time estimation). Otherwise, it will fallback to
                         a simple output of progress.
                         Default: False
+            enable_cross_tensor_attribution (bool, optional): If True, feature
+                        IDs in feature_mask are global IDs across input tensors,
+                        and are ablated together.
+                        Default: False
             **kwargs (Any, optional): Any additional arguments used by child
                         classes of FeatureAblation (such as Occlusion) to
                         construct ablations. These arguments are ignored when using
@@ -274,10 +287,12 @@ def attribute(
             isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
         ), "Perturbations per evaluation must be an integer and at least 1."
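        # A minimal usage sketch of the new flag (illustrative only, not part
        # of this patch; `net`, the shapes, and the mask IDs below are
        # hypothetical). With enable_cross_tensor_attribution=True, mask IDs
        # are global across tensors, so ID 1 below marks a single feature that
        # spans both inputs and is ablated in one step:
        #
        #     ablation = FeatureAblation(net)
        #     inp1, inp2 = torch.randn(2, 3), torch.randn(2, 4)
        #     mask1 = torch.tensor([[0, 0, 1]])
        #     mask2 = torch.tensor([[1, 2, 2, 2]])
        #     attr1, attr2 = ablation.attribute(
        #         (inp1, inp2),
        #         target=0,
        #         feature_mask=(mask1, mask2),
        #         enable_cross_tensor_attribution=True,
        #     )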
with torch.no_grad(): + attr_progress = None if show_progress: attr_progress = self._attribute_progress_setup( formatted_inputs, formatted_feature_mask, + enable_cross_tensor_attribution, **kwargs, perturbations_per_eval=perturbations_per_eval, ) @@ -291,7 +306,7 @@ def attribute( target, formatted_additional_forward_args, ) - if show_progress: + if attr_progress is not None: attr_progress.update() total_attrib: List[Tensor] = [] @@ -318,69 +333,239 @@ def attribute( formatted_inputs, ) - # Iterate through each feature tensor for ablation - for i in range(len(formatted_inputs)): - # Skip any empty input tensors - if torch.numel(formatted_inputs[i]) == 0: - continue - - for ( - current_inputs, - current_add_args, - current_target, - current_mask, - ) in self._ith_input_ablation_generator( - i, + if enable_cross_tensor_attribution: + total_attrib, weights = self._attribute_with_cross_tensor_feature_masks( + formatted_inputs, + formatted_additional_forward_args, + target, + baselines, + formatted_feature_mask, + attr_progress, + flattened_initial_eval, + n_outputs, + total_attrib, + weights, + attrib_type, + **kwargs, + ) + else: + total_attrib, weights = self._attribute_with_independent_feature_masks( formatted_inputs, formatted_additional_forward_args, target, baselines, formatted_feature_mask, + num_examples, perturbations_per_eval, + attr_progress, + initial_eval, + flattened_initial_eval, + n_outputs, + total_attrib, + weights, + attrib_type, **kwargs, - ): - # modified_eval has (n_feature_perturbed * n_outputs) elements - # shape: - # agg mode: (*initial_eval.shape) - # non-agg mode: - # (feature_perturbed * batch_size, *initial_eval.shape[1:]) - modified_eval: Union[Tensor, Future[Tensor]] = _run_forward( - self.forward_func, - current_inputs, - current_target, - current_add_args, - ) + ) - if show_progress: - attr_progress.update() + if attr_progress is not None: + attr_progress.close() - assert not isinstance(modified_eval, torch.Future), ( - "when use_futures is True, modified_eval should have " - f"non-Future type rather than {type(modified_eval)}" - ) - total_attrib, weights = self._process_ablated_out( - modified_eval, - current_inputs, - current_mask, - perturbations_per_eval, - num_examples, - initial_eval, - flattened_initial_eval, - formatted_inputs, - n_outputs, - total_attrib, - weights, - i, - attrib_type, - ) + # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <: + # [Tensor, typing.Tuple[Tensor, ...]]]` + # but got `Union[Tensor, typing.Tuple[Tensor, ...]]`. 
+ return self._generate_result(total_attrib, weights, is_inputs_tuple) # type: ignore # noqa: E501 line too long - if show_progress: - attr_progress.close() + def _attribute_with_independent_feature_masks( + self, + formatted_inputs: Tuple[Tensor, ...], + formatted_additional_forward_args: Optional[Tuple[object, ...]], + target: TargetType, + baselines: BaselineType, + formatted_feature_mask: Tuple[Tensor, ...], + num_examples: int, + perturbations_per_eval: int, + attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], + initial_eval: Tensor, + flattened_initial_eval: Tensor, + n_outputs: int, + total_attrib: List[Tensor], + weights: List[Tensor], + attrib_type: dtype, + **kwargs: Any, + ) -> Tuple[List[Tensor], List[Tensor]]: + # Iterate through each feature tensor for ablation + for i in range(len(formatted_inputs)): + # Skip any empty input tensors + if torch.numel(formatted_inputs[i]) == 0: + continue + + for ( + current_inputs, + current_add_args, + current_target, + current_mask, + ) in self._ith_input_ablation_generator( + i, + formatted_inputs, + formatted_additional_forward_args, + target, + baselines, + formatted_feature_mask, + perturbations_per_eval, + **kwargs, + ): + # modified_eval has (n_feature_perturbed * n_outputs) elements + # shape: + # agg mode: (*initial_eval.shape) + # non-agg mode: + # (feature_perturbed * batch_size, *initial_eval.shape[1:]) + modified_eval: Union[Tensor, Future[Tensor]] = _run_forward( + self.forward_func, + current_inputs, + current_target, + current_add_args, + ) + + if attr_progress is not None: + attr_progress.update() + + assert not isinstance(modified_eval, torch.Future), ( + "when use_futures is True, modified_eval should have " + f"non-Future type rather than {type(modified_eval)}" + ) + total_attrib, weights = self._process_ablated_out( + modified_eval, + current_inputs, + current_mask, + perturbations_per_eval, + num_examples, + initial_eval, + flattened_initial_eval, + formatted_inputs, + n_outputs, + total_attrib, + weights, + i, + attrib_type, + ) + return total_attrib, weights + + def _attribute_with_cross_tensor_feature_masks( + self, + formatted_inputs: Tuple[Tensor, ...], + formatted_additional_forward_args: Optional[Tuple[object, ...]], + target: TargetType, + baselines: BaselineType, + formatted_feature_mask: Tuple[Tensor, ...], + attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], + flattened_initial_eval: Tensor, + n_outputs: int, + total_attrib: List[Tensor], + weights: List[Tensor], + attrib_type: dtype, + **kwargs: Any, + ) -> Tuple[List[Tensor], List[Tensor]]: + for ( + current_inputs, + current_mask, + ) in self._ablation_generator( + formatted_inputs, + baselines, + formatted_feature_mask, + **kwargs, + ): + # modified_eval has (n_feature_perturbed * n_outputs) elements + # shape: + # agg mode: (*initial_eval.shape) + # non-agg mode: + # (feature_perturbed * batch_size, *initial_eval.shape[1:]) + modified_eval = _run_forward( + self.forward_func, + current_inputs, + target, + formatted_additional_forward_args, + ) - # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <: - # [Tensor, typing.Tuple[Tensor, ...]]]` - # but got `Union[Tensor, typing.Tuple[Tensor, ...]]`. 
-        return self._generate_result(total_attrib, weights, is_inputs_tuple)  # type: ignore # noqa: E501 line too long
+            if attr_progress is not None:
+                attr_progress.update()
+
+            assert not isinstance(modified_eval, torch.Future), (
+                "when use_futures is True, modified_eval should have "
+                f"non-Future type rather than {type(modified_eval)}"
+            )
+
+            total_attrib, weights = self._process_ablated_out_full(
+                modified_eval,
+                current_mask,
+                flattened_initial_eval,
+                formatted_inputs,
+                n_outputs,
+                total_attrib,
+                weights,
+                attrib_type,
+            )
+        return total_attrib, weights
+
+    def _ablation_generator(
+        self,
+        inputs: Tuple[Tensor, ...],
+        baselines: BaselineType,
+        input_mask: Tuple[Tensor, ...],
+        **kwargs: Any,
+    ) -> Generator[
+        Tuple[
+            Tuple[Tensor, ...],
+            Tuple[Tensor, ...],
+        ],
+        None,
+        None,
+    ]:
+        unique_feature_ids = torch.unique(
+            torch.cat([mask.flatten() for mask in input_mask])
+        ).tolist()
+
+        if isinstance(baselines, torch.Tensor):
+            baselines = baselines.reshape((1,) + tuple(baselines.shape))
+
+        # Process one feature at a time, rather than processing every input tensor
+        for feature_idx in unique_feature_ids:
+            ablated_inputs, current_masks = (
+                self._construct_ablated_input_across_tensors(
+                    inputs, input_mask, baselines, feature_idx
+                )
+            )
+            yield ablated_inputs, current_masks
+
+    def _construct_ablated_input_across_tensors(
+        self,
+        inputs: Tuple[Tensor, ...],
+        input_mask: Tuple[Tensor, ...],
+        baselines: BaselineType,
+        feature_idx: int,
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
+
+        ablated_inputs = []
+        current_masks = []
+        for i, input_tensor in enumerate(inputs):
+            mask = input_mask[i]
+            tensor_mask = mask == feature_idx
+            if not tensor_mask.any():
+                ablated_inputs.append(input_tensor)
+                current_masks.append(torch.zeros_like(tensor_mask))
+                continue
+            tensor_mask = tensor_mask.to(input_tensor.device).long()
+            baseline = baselines[i] if isinstance(baselines, tuple) else baselines
+            if isinstance(baseline, torch.Tensor):
+                baseline = baseline.reshape(
+                    (1,) * (input_tensor.dim() - baseline.dim()) + tuple(baseline.shape)
+                )
+            assert baseline is not None, "baseline must be provided"
+            ablated_input = (
+                input_tensor * (1 - tensor_mask).to(input_tensor.dtype)
+            ) + (baseline * tensor_mask.to(input_tensor.dtype))
+            ablated_inputs.append(ablated_input)
+            current_masks.append(tensor_mask)
+        return tuple(ablated_inputs), tuple(current_masks)

     def _initial_eval_to_processed_initial_eval_fut(
         self, initial_eval: Future[Tensor], formatted_inputs: Tuple[Tensor, ...]
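The generator above visits one global feature ID at a time and, for each ID, swaps every position carrying that ID to its baseline in all input tensors at once. A minimal standalone sketch of that step, assuming a scalar baseline of zero (tensor values and mask IDs below are hypothetical, not taken from the patch):

    import torch

    inputs = (torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0]]))
    masks = (torch.tensor([[0, 1]]), torch.tensor([[1, 2]]))
    for fid in (0, 1, 2):
        ablated = tuple(
            torch.where(m == fid, torch.zeros_like(x), x)  # baseline = 0
            for x, m in zip(inputs, masks)
        )
        # fid == 1 zeroes inputs[0][0, 1] and inputs[1][0, 0] in the same pass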
@@ -572,6 +757,7 @@ def _attribute_progress_setup( self, formatted_inputs: Tuple[Tensor, ...], feature_mask: Tuple[Tensor, ...], + enable_cross_tensor_attribution: bool, perturbations_per_eval: int, **kwargs: Any, ): @@ -579,9 +765,13 @@ def _attribute_progress_setup( formatted_inputs, feature_mask, **kwargs ) total_forwards = ( - sum(math.ceil(count / perturbations_per_eval) for count in feature_counts) - + 1 - ) # add 1 for the initial eval + int(sum(feature_counts)) + if enable_cross_tensor_attribution + else sum( + math.ceil(count / perturbations_per_eval) for count in feature_counts + ) + ) + total_forwards += 1 # add 1 for the initial eval attr_progress = progress( desc=f"{self.get_name()} attribution", total=total_forwards ) @@ -808,10 +998,7 @@ def _construct_ablated_input( current_mask = current_mask.to(expanded_input.device) assert baseline is not None, "baseline must be provided" ablated_tensor = ( - expanded_input - * (1 - current_mask).to(expanded_input.dtype) - # pyre-fixme[58]: `*` is not supported for operand types `Union[None, float, - # Tensor]` and `Tensor`. + expanded_input * (1 - current_mask).to(expanded_input.dtype) ) + (baseline * current_mask.to(expanded_input.dtype)) return ablated_tensor, current_mask @@ -983,6 +1170,42 @@ def _process_ablated_out( total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(dim=0) return total_attrib, weights + def _process_ablated_out_full( + self, + modified_eval: Tensor, + current_mask: Tuple[Tensor, ...], + flattened_initial_eval: Tensor, + inputs: TensorOrTupleOfTensorsGeneric, + n_outputs: int, + total_attrib: List[Tensor], + weights: List[Tensor], + attrib_type: dtype, + ) -> Tuple[List[Tensor], List[Tensor]]: + modified_eval = self._parse_forward_out(modified_eval) + + # reshape the leading dim for n_feature_perturbed + # flatten each feature's eval outputs into 1D of (n_outputs) + modified_eval = modified_eval.reshape(-1, n_outputs) + # eval_diff in shape (n_feature_perturbed, n_outputs) + eval_diff = flattened_initial_eval - modified_eval + eval_diff_shape = eval_diff.shape + + # append the shape of one input example + # to make it broadcastable to mask + + if self.use_weights: + for weight, mask in zip(weights, current_mask): + weight += mask.float().sum(dim=0) + for i, mask in enumerate(current_mask): + if inputs[i].numel() == 0: + continue + eval_diff = eval_diff.reshape( + eval_diff_shape + (inputs[i].dim() - 1) * (1,) + ) + eval_diff = eval_diff.to(total_attrib[i].device) + total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0) + return total_attrib, weights + def _fut_tuple_to_accumulate_fut_list( self, total_attrib: List[Tensor], diff --git a/setup.py b/setup.py index 588f8f3a28..ee72439439 100755 --- a/setup.py +++ b/setup.py @@ -69,7 +69,7 @@ def report(*args): + [ "black", "flake8", - "sphinx", + "sphinx<8.2.0", "sphinx-autodoc-typehints", "sphinxcontrib-katex", "mypy>=0.760", diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py index 3646bd6c58..c8f9802d6a 100644 --- a/tests/attr/test_feature_ablation.py +++ b/tests/attr/test_feature_ablation.py @@ -170,8 +170,8 @@ def test_multi_input_ablation_with_mask(self) -> None: inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]]) - mask2 = torch.tensor([[0, 1, 2]]) - mask3 = torch.tensor([[0, 1, 2], [0, 0, 0]]) + mask2 = torch.tensor([[3, 4, 2]]) + mask3 = torch.tensor([[5, 6, 7], [5, 5, 5]]) expected = ( 
[[492.0, 492.0, 492.0], [200.0, 200.0, 200.0]], [[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]], @@ -207,8 +207,8 @@ def test_multi_input_ablation_with_mask(self) -> None: perturbations_per_eval=(1, 2, 3), ) - def test_multi_input_ablation_with_mask_nt(self) -> None: - ablation_algo = NoiseTunnel(FeatureAblation(BasicModel_MultiLayer_MultiInput())) + def test_multi_input_ablation_with_mask_dupe_feature_idx(self) -> None: + ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput()) inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) @@ -220,6 +220,66 @@ def test_multi_input_ablation_with_mask_nt(self) -> None: [[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]], [[0.0, 400.0, 40.0], [60.0, 60.0, 60.0]], ) + expected_cross_tensor = ( + [[1092.0, 1092.0, 1092.0], [260.0, 600.0, 260.0]], + [[80.0, 1092.0, 160.0], [260.0, 600.0, 0.0]], + [[80.0, 1092.0, 160.0], [260.0, 260.0, 260.0]], + ) + for test_enable_cross_tensor_attribution, expected_out in [ + (True, expected_cross_tensor), + (False, expected), + ]: + self._ablation_test_assert( + ablation_algo, + (inp1, inp2, inp3), + expected_out, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), + test_enable_cross_tensor_attribution=[ + test_enable_cross_tensor_attribution + ], + ) + + expected_with_baseline = ( + [[468.0, 468.0, 468.0], [184.0, 192.0, 184.0]], + [[68.0, 188.0, 108.0], [-12.0, 388.0, -12.0]], + [[-16.0, 384.0, 24.0], [12.0, 12.0, 12.0]], + ) + expected_cross_tensor_with_baseline = ( + [[1040.0, 1040.0, 1040.0], [184.0, 580.0, 184.0]], + [[52.0, 1040.0, 132.0], [184.0, 580.0, -12.0]], + [[52.0, 1040.0, 132.0], [184.0, 184.0, 184.0]], + ) + for test_enable_cross_tensor_attribution, expected_out in [ + (True, expected_cross_tensor_with_baseline), + (False, expected_with_baseline), + ]: + self._ablation_test_assert( + ablation_algo, + (inp1, inp2, inp3), + expected_out, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), + baselines=(2, 3.0, 4), + perturbations_per_eval=(1, 2, 3), + test_enable_cross_tensor_attribution=[ + test_enable_cross_tensor_attribution + ], + ) + + def test_multi_input_ablation_with_mask_nt(self) -> None: + ablation_algo = NoiseTunnel(FeatureAblation(BasicModel_MultiLayer_MultiInput())) + inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) + inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) + inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) + mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]]) + mask2 = torch.tensor([[3, 4, 2]]) + mask3 = torch.tensor([[5, 6, 7], [5, 5, 5]]) + expected = ( + [[492.0, 492.0, 492.0], [200.0, 200.0, 200.0]], + [[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]], + [[0.0, 400.0, 40.0], [60.0, 60.0, 60.0]], + ) self._ablation_test_assert( ablation_algo, (inp1, inp2, inp3), @@ -694,8 +754,8 @@ def _multi_input_batch_scalar_ablation_assert( inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) mask1 = torch.tensor([[1, 1, 1]]) - mask2 = torch.tensor([[0, 1, 2]]) - mask3 = torch.tensor([[0, 1, 2]]) + mask2 = torch.tensor([[0, 3, 2]]) + mask3 = torch.tensor([[4, 5, 6]]) expected = ( torch.tensor([[1784, 1784, 1784]], dtype=dtype), torch.tensor([[160, 1200, 240]], dtype=dtype), @@ -737,47 +797,50 @@ def _ablation_test_assert( perturbations_per_eval: Tuple[int, ...] 
= (1,), baselines: BaselineType = None, target: TargetType = 0, + test_enable_cross_tensor_attribution: List[bool] = [True, False], test_future: bool = False, **kwargs: Any, ) -> None: - for batch_size in perturbations_per_eval: - self.assertTrue(ablation_algo.multiplies_by_inputs) - if isinstance(ablation_algo, FeatureAblation) and test_future: - attributions = ablation_algo.attribute_future( - test_input, - target=target, - feature_mask=feature_mask, - additional_forward_args=additional_input, - baselines=baselines, - perturbations_per_eval=batch_size, - **kwargs, - ).wait() - else: - attributions = ablation_algo.attribute( - test_input, - target=target, - feature_mask=feature_mask, - additional_forward_args=additional_input, - baselines=baselines, - perturbations_per_eval=batch_size, - **kwargs, - ) - if isinstance(expected_ablation, tuple): - for i in range(len(expected_ablation)): - expected = expected_ablation[i] - if not isinstance(expected, torch.Tensor): - expected = torch.tensor(expected) - - self.assertEqual(attributions[i].shape, expected.shape) - self.assertEqual(attributions[i].dtype, expected.dtype) - assertTensorAlmostEqual(self, attributions[i], expected) - else: - if not isinstance(expected_ablation, torch.Tensor): - expected_ablation = torch.tensor(expected_ablation) - - self.assertEqual(attributions.shape, expected_ablation.shape) - self.assertEqual(attributions.dtype, expected_ablation.dtype) - assertTensorAlmostEqual(self, attributions, expected_ablation) + for enable_cross_tensor_attribution in test_enable_cross_tensor_attribution: + for batch_size in perturbations_per_eval: + self.assertTrue(ablation_algo.multiplies_by_inputs) + if isinstance(ablation_algo, FeatureAblation) and test_future: + attributions = ablation_algo.attribute_future( + test_input, + target=target, + feature_mask=feature_mask, + additional_forward_args=additional_input, + baselines=baselines, + perturbations_per_eval=batch_size, + **kwargs, + ).wait() + else: + attributions = ablation_algo.attribute( + test_input, + target=target, + feature_mask=feature_mask, + additional_forward_args=additional_input, + baselines=baselines, + perturbations_per_eval=batch_size, + enable_cross_tensor_attribution=enable_cross_tensor_attribution, + **kwargs, + ) + if isinstance(expected_ablation, tuple): + for i in range(len(expected_ablation)): + expected = expected_ablation[i] + if not isinstance(expected, torch.Tensor): + expected = torch.tensor(expected) + + self.assertEqual(attributions[i].shape, expected.shape) + self.assertEqual(attributions[i].dtype, expected.dtype) + assertTensorAlmostEqual(self, attributions[i], expected) + else: + if not isinstance(expected_ablation, torch.Tensor): + expected_ablation = torch.tensor(expected_ablation) + + self.assertEqual(attributions.shape, expected_ablation.shape) + self.assertEqual(attributions.dtype, expected_ablation.dtype) + assertTensorAlmostEqual(self, attributions, expected_ablation) if __name__ == "__main__": From 3b0c1807bed5adbac1b99e07d2ea7e13c51751df Mon Sep 17 00:00:00 2001 From: Sarah Tran Date: Thu, 20 Feb 2025 16:22:10 -0800 Subject: [PATCH 2/3] Support permuting across input tensors in FeaturePermutation (#1507) Summary: Most of the logic is in the parent class `FeatureAblation`, but to support feature grouping across input tensors we need to update the permutation utils too Reviewed By: cyrjano Differential Revision: D69867208 --- captum/attr/_core/feature_permutation.py | 63 ++++++++- tests/attr/test_feature_permutation.py | 161 
++++++++++++++++------- 2 files changed, 174 insertions(+), 50 deletions(-) diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py index 79c519602f..1fc85d16fe 100644 --- a/captum/attr/_core/feature_permutation.py +++ b/captum/attr/_core/feature_permutation.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Optional, Tuple, Union import torch -from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric +from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._core.feature_ablation import FeatureAblation from captum.log import log_usage from torch import Tensor @@ -25,6 +25,31 @@ def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor: ) +def _permute_features_across_tensors( + inputs: Tuple[Tensor, ...], feature_masks: Tuple[Tensor, ...] +) -> Tuple[Tensor, ...]: + """ + Permutes features across multiple input tensors using the corresponding + feature masks. + """ + permuted_outputs = [] + for input_tensor, feature_mask in zip(inputs, feature_masks): + if not feature_mask.any(): + permuted_outputs.append(input_tensor) + continue + n = input_tensor.size(0) + assert n > 1, "cannot permute features with batch_size = 1" + perm = torch.randperm(n) + no_perm = torch.arange(n) + while (perm == no_perm).all(): + perm = torch.randperm(n) + permuted_x = ( + input_tensor[perm] * feature_mask.to(dtype=input_tensor.dtype) + ) + (input_tensor * feature_mask.bitwise_not().to(dtype=input_tensor.dtype)) + permuted_outputs.append(permuted_x) + return tuple(permuted_outputs) + + class FeaturePermutation(FeatureAblation): r""" A perturbation based approach to compute attribution, which @@ -55,7 +80,8 @@ class FeaturePermutation(FeatureAblation): of examples to compute attributions and cannot be performed on a single example. By default, each scalar value within - each input tensor is taken as a feature and shuffled independently. Passing + each input tensor is taken as a feature and shuffled independently, *unless* + attribute() is called with enable_cross_tensor_attribution=True. Passing a feature mask, allows grouping features to be shuffled together. Each input scalar in the group will be given the same attribution value equal to the change in target as a result of shuffling the entire feature @@ -76,6 +102,9 @@ def __init__( self, forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]], perm_func: Callable[[Tensor, Tensor], Tensor] = _permute_feature, + perm_func_cross_tensor: Callable[ + [Tuple[Tensor, ...], Tuple[Tensor, ...]], Tuple[Tensor, ...] + ] = _permute_features_across_tensors, ) -> None: r""" Args: @@ -88,9 +117,14 @@ def __init__( which applies a random permutation, this argument only needs to be provided if a custom permutation behavior is desired. Default: `_permute_feature` + perm_func_cross_tensor (Callable, optional): Similar to perm_func, + except it can permute grouped features across multiple input + tensors, rather than taking each input tensor independently. 
+ Default: `_permute_features_across_tensors` """ FeatureAblation.__init__(self, forward_func=forward_func) self.perm_func = perm_func + self.perm_func_cross_tensor = perm_func_cross_tensor # suppressing error caused by the child class not having a matching # signature to the parent @@ -103,6 +137,7 @@ def attribute( # type: ignore feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, show_progress: bool = False, + enable_cross_tensor_attribution: bool = False, **kwargs: Any, ) -> TensorOrTupleOfTensorsGeneric: r""" @@ -176,14 +211,16 @@ def attribute( # type: ignore corresponding to the same feature should have the same value. Note that features within each input tensor are ablated independently (not across - tensors). + tensors), unless enable_cross_tensor_attribution is + True. The first dimension of each mask must be 1, as we require to have the same group of features for each input sample. If None, then a feature mask is constructed which assigns each scalar within a tensor as a separate feature, which - is permuted independently. + is permuted independently, unless + enable_cross_tensor_attribution is True. Default: None perturbations_per_eval (int, optional): Allows permutations of multiple features to be processed simultaneously @@ -202,6 +239,10 @@ def attribute( # type: ignore (e.g. time estimation). Otherwise, it will fallback to a simple output of progress. Default: False + enable_cross_tensor_attribution (bool, optional): If True, then + features can be grouped across input tensors depending on + the values in the feature mask. + Default: False **kwargs (Any, optional): Any additional arguments used by child classes of :class:`.FeatureAblation` (such as :class:`.Occlusion`) to construct ablations. 
These @@ -273,6 +314,7 @@ def attribute( # type: ignore feature_mask=feature_mask, perturbations_per_eval=perturbations_per_eval, show_progress=show_progress, + enable_cross_tensor_attribution=enable_cross_tensor_attribution, **kwargs, ) @@ -343,3 +385,16 @@ def _construct_ablated_input( ] ) return output, current_mask + + def _construct_ablated_input_across_tensors( + self, + inputs: Tuple[Tensor, ...], + input_mask: Tuple[Tensor, ...], + baselines: BaselineType, + feature_idx: int, + ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: + feature_masks = tuple( + (mask == feature_idx).to(inputs[0].device) for mask in input_mask + ) + permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks) + return permuted_outputs, feature_masks diff --git a/tests/attr/test_feature_permutation.py b/tests/attr/test_feature_permutation.py index e1795e0ebe..611b19238a 100644 --- a/tests/attr/test_feature_permutation.py +++ b/tests/attr/test_feature_permutation.py @@ -7,7 +7,7 @@ import torch from captum.attr._core.feature_permutation import _permute_feature, FeaturePermutation from captum.testing.helpers import BaseTest -from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic import assertTensorAlmostEqual, set_all_random_seeds from captum.testing.helpers.basic_models import BasicModelWithSparseInputs from torch import Tensor @@ -103,12 +103,14 @@ def forward_func(x: Tensor) -> Tensor: inp[:, 0] = constant_value zeros = torch.zeros_like(inp[:, 0]) - - attribs = feature_importance.attribute(inp) - - self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size) - assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max") - self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all()) + for enable_cross_tensor_attribution in (True, False): + attribs = feature_importance.attribute( + inp, + enable_cross_tensor_attribution=enable_cross_tensor_attribution, + ) + self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size) + assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max") + self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all()) def test_single_input_with_future( self, @@ -166,12 +168,63 @@ def forward_func(*x: Tensor) -> Tensor: feature_mask = ( torch.arange(inp[0][0].numel()).view_as(inp[0][0]).unsqueeze(0), - torch.arange(inp[1][0].numel()).view_as(inp[1][0]).unsqueeze(0), + torch.arange(inp[0][0].numel(), inp[0][0].numel() + inp[1][0].numel()) + .view_as(inp[1][0]) + .unsqueeze(0), ) inp[1][:, :, 1] = 4 + for enable_cross_tensor_attribution in (True, False): + attribs = feature_importance.attribute( + inp, + feature_mask=feature_mask, + enable_cross_tensor_attribution=enable_cross_tensor_attribution, + ) + + self.assertTrue(isinstance(attribs, tuple)) + self.assertTrue(len(attribs) == 2) + + self.assertTrue(attribs[0].squeeze(0).size() == inp1_size) + self.assertTrue(attribs[1].squeeze(0).size() == inp2_size) + + self.assertTrue((attribs[1][:, :, 1] == 0).all()) + self.assertTrue((attribs[1][:, :, 2] == 0).all()) + + self.assertTrue((attribs[0] != 0).all()) + self.assertTrue((attribs[1][:, :, 0] != 0).all()) + + def test_multi_input_group_across_input_tensors( + self, + ) -> None: + batch_size = 20 + inp1_size = (5, 2) + inp2_size = (5, 3) + + labels: Tensor = torch.randn(batch_size) + + def forward_func(*x: Tensor) -> Tensor: + y = torch.zeros(x[0].shape[0:2]) + for xx in x: + y += xx[:, :, 0] * xx[:, :, 1] + y = y.sum(dim=-1) + + # pyre-fixme[58]: `**` is not supported for 
operand types `Tensor` and
+            # `int`.
+            return torch.mean((y - labels) ** 2)

-        attribs = feature_importance.attribute(inp, feature_mask=feature_mask)
+        feature_importance = FeaturePermutation(forward_func=forward_func)
+
+        inp = (
+            torch.randn((batch_size,) + inp1_size),
+            torch.randn((batch_size,) + inp2_size),
+        )
+        # Group all features together
+        feature_mask = tuple(
+            torch.zeros_like(inp_tensor[0]).unsqueeze(0) for inp_tensor in inp
+        )
+        attribs = feature_importance.attribute(
+            inp, feature_mask=feature_mask, enable_cross_tensor_attribution=True
+        )

         self.assertTrue(isinstance(attribs, tuple))
         self.assertTrue(len(attribs) == 2)
@@ -179,11 +232,11 @@ def forward_func(*x: Tensor) -> Tensor:
         self.assertTrue(attribs[0].squeeze(0).size() == inp1_size)
         self.assertTrue(attribs[1].squeeze(0).size() == inp2_size)

-        self.assertTrue((attribs[1][:, :, 1] == 0).all())
-        self.assertTrue((attribs[1][:, :, 2] == 0).all())
-
-        self.assertTrue((attribs[0] != 0).all())
-        self.assertTrue((attribs[1][:, :, 0] != 0).all())
+        first_elem_first_attrib = attribs[0].flatten()[0]
+        first_elem_second_attrib = attribs[1].flatten()[0]
+        self.assertTrue(torch.all(attribs[0] == first_elem_first_attrib))
+        self.assertTrue(torch.all(attribs[1] == first_elem_second_attrib))
+        self.assertEqual(first_elem_first_attrib, first_elem_second_attrib)

     def test_multi_input_with_future(
         self,
@@ -324,26 +377,30 @@ def forward_func(x: Tensor) -> Tensor:
             torch.tensor([[0, 1, 2, 3]]),
             torch.tensor([[[0, 1, 2, 3], [3, 3, 4, 5], [6, 6, 4, 6], [7, 8, 9, 10]]]),
         ]
+        for enable_cross_tensor_attribution in (True, False):
+            for mask in masks:

-        for mask in masks:
-
-            attribs = feature_importance.attribute(inp, feature_mask=mask)
-            self.assertTrue(attribs is not None)
-            self.assertTrue(attribs.shape == inp.shape)
-
-            fm = mask.expand_as(inp[0])
-
-            features = set(mask.flatten())
-            for feature in features:
-                m = (fm == feature).bool()
-                attribs_for_feature = attribs[:, m]
-                assertTensorAlmostEqual(
-                    self,
-                    attribs_for_feature[0],
-                    -attribs_for_feature[1],
-                    delta=0.05,
-                    mode="max",
+                attribs = feature_importance.attribute(
+                    inp,
+                    feature_mask=mask,
+                    enable_cross_tensor_attribution=enable_cross_tensor_attribution,
                 )
+                self.assertTrue(attribs is not None)
+                self.assertTrue(attribs.shape == inp.shape)
+
+                fm = mask.expand_as(inp[0])
+
+                features = set(mask.flatten())
+                for feature in features:
+                    m = (fm == feature).bool()
+                    attribs_for_feature = attribs[:, m]
+                    assertTensorAlmostEqual(
+                        self,
+                        attribs_for_feature[0],
+                        -attribs_for_feature[1],
+                        delta=0.05,
+                        mode="max",
+                    )

     def test_broadcastable_masks_with_future(
         self,
@@ -399,9 +456,13 @@ def test_empty_sparse_features(self) -> None:
         # test empty sparse tensor
         feature_importance = FeaturePermutation(model)
-        attr1, attr2 = feature_importance.attribute((inp1, inp2))
-        self.assertEqual(attr1.shape, (1, 3))
-        self.assertEqual(attr2.shape, (1,))
+        for enable_cross_tensor_attribution in (True, False):
+            attr1, attr2 = feature_importance.attribute(
+                (inp1, inp2),
+                enable_cross_tensor_attribution=enable_cross_tensor_attribution,
+            )
+            self.assertEqual(attr1.shape, (1, 3))
+            self.assertEqual(attr2.shape, (1,))

     def test_sparse_features(self) -> None:
         model = BasicModelWithSparseInputs()
@@ -410,14 +471,22 @@ def test_sparse_features(self) -> None:
         inp2 = torch.tensor([1, 7, 2, 4, 5, 3, 6])

         feature_importance = FeaturePermutation(model)
-        total_attr1, total_attr2 = feature_importance.attribute((inp1, inp2))
-
-        for _ in range(50):
-            attr1, attr2 = feature_importance.attribute((inp1,
inp2)) - total_attr1 += attr1 - total_attr2 += attr2 - total_attr1 /= 50 - total_attr2 /= 50 - self.assertEqual(total_attr2.shape, (1,)) - assertTensorAlmostEqual(self, total_attr1, torch.zeros_like(total_attr1)) - assertTensorAlmostEqual(self, total_attr2, [-6.0], delta=0.2) + + for enable_cross_tensor_attribution in [True, False]: + set_all_random_seeds(1234) + total_attr1, total_attr2 = feature_importance.attribute( + (inp1, inp2), + enable_cross_tensor_attribution=enable_cross_tensor_attribution, + ) + for _ in range(50): + attr1, attr2 = feature_importance.attribute( + (inp1, inp2), + enable_cross_tensor_attribution=enable_cross_tensor_attribution, + ) + total_attr1 += attr1 + total_attr2 += attr2 + total_attr1 /= 50 + total_attr2 /= 50 + self.assertEqual(total_attr2.shape, (1,)) + assertTensorAlmostEqual(self, total_attr1, torch.zeros_like(total_attr1)) + assertTensorAlmostEqual(self, total_attr2, [-6.0], delta=0.2) From 7fc86ea3042f7d36e34a4e77582749a604a29edf Mon Sep 17 00:00:00 2001 From: Sarah Tran Date: Thu, 20 Feb 2025 16:22:10 -0800 Subject: [PATCH 3/3] Add enable_cross_tensor_attribution tests to test_config Summary: TSIA Differential Revision: D69957243 --- captum/testing/attr/helpers/test_config.py | 82 ++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/captum/testing/attr/helpers/test_config.py b/captum/testing/attr/helpers/test_config.py index effe71a69a..d97da0f169 100644 --- a/captum/testing/attr/helpers/test_config.py +++ b/captum/testing/attr/helpers/test_config.py @@ -112,6 +112,19 @@ "model": BasicModel_MultiLayer(), "attribute_args": {"inputs": torch.randn(4, 3), "target": 1}, }, + { + "name": "basic_single_target_cross_tensor_attributions", + "algorithms": [ + FeatureAblation, + FeaturePermutation, + ], + "model": BasicModel_MultiLayer(), + "attribute_args": { + "inputs": torch.randn(4, 3), + "target": 1, + "enable_cross_tensor_attribution": True, + }, + }, { "name": "basic_multi_input", "algorithms": [ @@ -179,6 +192,21 @@ }, "dp_delta": 0.0005, }, + { + "name": "basic_multi_input_multi_target_cross_tensor_attributions", + "algorithms": [ + FeatureAblation, + FeaturePermutation, + ], + "model": BasicModel_MultiLayer_MultiInput(), + "attribute_args": { + "inputs": (10 * torch.randn(6, 3), 5 * torch.randn(6, 3)), + "additional_forward_args": (2 * torch.randn(6, 3), 5), + "target": [0, 1, 1, 0, 0, 1], + "enable_cross_tensor_attribution": True, + }, + "dp_delta": 0.0005, + }, { "name": "basic_multiple_tuple_target", "algorithms": [ @@ -202,6 +230,20 @@ "additional_forward_args": (None, True), }, }, + { + "name": "basic_multiple_tuple_target_cross_tensor_attributions", + "algorithms": [ + FeatureAblation, + FeaturePermutation, + ], + "model": BasicModel_MultiLayer(), + "attribute_args": { + "inputs": torch.randn(4, 3), + "target": [(1, 0, 0), (0, 1, 1), (1, 1, 1), (0, 0, 0)], + "additional_forward_args": (None, True), + "enable_cross_tensor_attribution": True, + }, + }, { "name": "basic_tensor_single_target", "algorithms": [ @@ -243,6 +285,19 @@ "target": torch.tensor([1, 1, 0, 0]), }, }, + { + "name": "basic_tensor_multi_target_cross_tensor_attributions", + "algorithms": [ + FeatureAblation, + FeaturePermutation, + ], + "model": BasicModel_MultiLayer(), + "attribute_args": { + "inputs": torch.randn(4, 3), + "target": torch.tensor([1, 1, 0, 0]), + "enable_cross_tensor_attribution": True, + }, + }, # Primary Configs with Baselines { "name": "basic_multiple_tuple_target_with_baselines", @@ -262,6 +317,20 @@ "additional_forward_args": (None, True), }, 
}, + { + "name": "basic_multiple_tuple_target_with_baselines_cross_tensor_attributions", + "algorithms": [ + FeatureAblation, + ], + "model": BasicModel_MultiLayer(), + "attribute_args": { + "inputs": torch.randn(4, 3), + "baselines": 0.5 * torch.randn(4, 3), + "target": [(1, 0, 0), (0, 1, 1), (1, 1, 1), (0, 0, 0)], + "additional_forward_args": (None, True), + "enable_cross_tensor_attribution": True, + }, + }, { "name": "basic_tensor_single_target_with_baselines", "algorithms": [ @@ -279,6 +348,19 @@ "target": torch.tensor([0]), }, }, + { + "name": "basic_tensor_single_target_with_baselines_cross_tensor_attributions", + "algorithms": [ + FeatureAblation, + ], + "model": BasicModel_MultiLayer(), + "attribute_args": { + "inputs": torch.randn(4, 3), + "baselines": 0.5 * torch.randn(4, 3), + "target": torch.tensor([0]), + "enable_cross_tensor_attribution": True, + }, + }, # Primary Configs with Internal Batching { "name": "basic_multiple_tuple_target_with_internal_batching",