From 09f9628910ac758c1f6bd173902c4837847371a8 Mon Sep 17 00:00:00 2001 From: Sarah Tran Date: Tue, 18 Feb 2025 14:02:44 -0800 Subject: [PATCH] Support feature grouping across input tensors for FeatureAblation (#1497) Summary: Basic support; doesn't currently support multiple perturbations per eval Reviewed By: cyrjano Differential Revision: D69531512 --- captum/attr/_core/feature_ablation.py | 351 +++++++++++++++++++++----- setup.py | 2 +- tests/attr/test_feature_ablation.py | 151 +++++++---- 3 files changed, 395 insertions(+), 109 deletions(-) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index e5e60bb465..5375dbb638 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -3,7 +3,7 @@ # pyre-strict import math -from typing import Any, Callable, cast, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Generator, List, Optional, Tuple, TypeVar, Union import torch from captum._utils.common import ( @@ -16,7 +16,7 @@ _run_forward, ) from captum._utils.exceptions import FeatureAblationFutureError -from captum._utils.progress import progress +from captum._utils.progress import progress, SimpleProgress from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import PerturbationAttribution from captum.attr._utils.common import _format_input_baseline @@ -24,6 +24,13 @@ from torch import dtype, Tensor from torch.futures import collect_all, Future +try: + from tqdm.auto import tqdm +except ImportError: + tqdm = None + +IterableType = TypeVar("IterableType") + class FeatureAblation(PerturbationAttribution): r""" @@ -79,6 +86,7 @@ def attribute( feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, perturbations_per_eval: int = 1, show_progress: bool = False, + enable_cross_tensor_attribution: bool = False, **kwargs: Any, ) -> TensorOrTupleOfTensorsGeneric: r""" @@ -171,7 +179,8 @@ def attribute( - 1, and indices corresponding to the same feature should have the same value. Note that features within each input tensor are ablated - independently (not across tensors). + independently (not across tensors), unless + enable_cross_tensor_attribution is True. If the forward function returns a single scalar per batch, we enforce that the first dimension of each mask must be 1, since attributions are returned batch-wise rather than per @@ -179,7 +188,7 @@ def attribute( same features (indices) in each input example. If None, then a feature mask is constructed which assigns each scalar within a tensor as a separate feature, which - is ablated independently. + is ablated independently by default. Default: None perturbations_per_eval (int, optional): Allows ablation of multiple features to be processed simultaneously in one call to @@ -202,6 +211,10 @@ def attribute( (e.g. time estimation). Otherwise, it will fallback to a simple output of progress. Default: False + enable_cross_tensor_attribution (bool, optional): If True, features + IDs in feature_mask are global IDs across input tensors, + and are ablated together. + Default: False **kwargs (Any, optional): Any additional arguments used by child classes of FeatureAblation (such as Occlusion) to construct ablations. These arguments are ignored when using @@ -274,10 +287,12 @@ def attribute( isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1 ), "Perturbations per evaluation must be an integer and at least 1." with torch.no_grad(): + attr_progress = None if show_progress: attr_progress = self._attribute_progress_setup( formatted_inputs, formatted_feature_mask, + enable_cross_tensor_attribution, **kwargs, perturbations_per_eval=perturbations_per_eval, ) @@ -291,7 +306,7 @@ def attribute( target, formatted_additional_forward_args, ) - if show_progress: + if attr_progress is not None: attr_progress.update() total_attrib: List[Tensor] = [] @@ -318,69 +333,239 @@ def attribute( formatted_inputs, ) - # Iterate through each feature tensor for ablation - for i in range(len(formatted_inputs)): - # Skip any empty input tensors - if torch.numel(formatted_inputs[i]) == 0: - continue - - for ( - current_inputs, - current_add_args, - current_target, - current_mask, - ) in self._ith_input_ablation_generator( - i, + if enable_cross_tensor_attribution: + total_attrib, weights = self._attribute_with_cross_tensor_feature_masks( + formatted_inputs, + formatted_additional_forward_args, + target, + baselines, + formatted_feature_mask, + attr_progress, + flattened_initial_eval, + n_outputs, + total_attrib, + weights, + attrib_type, + **kwargs, + ) + else: + total_attrib, weights = self._attribute_with_independent_feature_masks( formatted_inputs, formatted_additional_forward_args, target, baselines, formatted_feature_mask, + num_examples, perturbations_per_eval, + attr_progress, + initial_eval, + flattened_initial_eval, + n_outputs, + total_attrib, + weights, + attrib_type, **kwargs, - ): - # modified_eval has (n_feature_perturbed * n_outputs) elements - # shape: - # agg mode: (*initial_eval.shape) - # non-agg mode: - # (feature_perturbed * batch_size, *initial_eval.shape[1:]) - modified_eval: Union[Tensor, Future[Tensor]] = _run_forward( - self.forward_func, - current_inputs, - current_target, - current_add_args, - ) + ) - if show_progress: - attr_progress.update() + if attr_progress is not None: + attr_progress.close() - assert not isinstance(modified_eval, torch.Future), ( - "when use_futures is True, modified_eval should have " - f"non-Future type rather than {type(modified_eval)}" - ) - total_attrib, weights = self._process_ablated_out( - modified_eval, - current_inputs, - current_mask, - perturbations_per_eval, - num_examples, - initial_eval, - flattened_initial_eval, - formatted_inputs, - n_outputs, - total_attrib, - weights, - i, - attrib_type, - ) + # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <: + # [Tensor, typing.Tuple[Tensor, ...]]]` + # but got `Union[Tensor, typing.Tuple[Tensor, ...]]`. + return self._generate_result(total_attrib, weights, is_inputs_tuple) # type: ignore # noqa: E501 line too long - if show_progress: - attr_progress.close() + def _attribute_with_independent_feature_masks( + self, + formatted_inputs: Tuple[Tensor, ...], + formatted_additional_forward_args: Optional[Tuple[object, ...]], + target: TargetType, + baselines: BaselineType, + formatted_feature_mask: Tuple[Tensor, ...], + num_examples: int, + perturbations_per_eval: int, + attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], + initial_eval: Tensor, + flattened_initial_eval: Tensor, + n_outputs: int, + total_attrib: List[Tensor], + weights: List[Tensor], + attrib_type: dtype, + **kwargs: Any, + ) -> Tuple[List[Tensor], List[Tensor]]: + # Iterate through each feature tensor for ablation + for i in range(len(formatted_inputs)): + # Skip any empty input tensors + if torch.numel(formatted_inputs[i]) == 0: + continue + + for ( + current_inputs, + current_add_args, + current_target, + current_mask, + ) in self._ith_input_ablation_generator( + i, + formatted_inputs, + formatted_additional_forward_args, + target, + baselines, + formatted_feature_mask, + perturbations_per_eval, + **kwargs, + ): + # modified_eval has (n_feature_perturbed * n_outputs) elements + # shape: + # agg mode: (*initial_eval.shape) + # non-agg mode: + # (feature_perturbed * batch_size, *initial_eval.shape[1:]) + modified_eval: Union[Tensor, Future[Tensor]] = _run_forward( + self.forward_func, + current_inputs, + current_target, + current_add_args, + ) + + if attr_progress is not None: + attr_progress.update() + + assert not isinstance(modified_eval, torch.Future), ( + "when use_futures is True, modified_eval should have " + f"non-Future type rather than {type(modified_eval)}" + ) + total_attrib, weights = self._process_ablated_out( + modified_eval, + current_inputs, + current_mask, + perturbations_per_eval, + num_examples, + initial_eval, + flattened_initial_eval, + formatted_inputs, + n_outputs, + total_attrib, + weights, + i, + attrib_type, + ) + return total_attrib, weights + + def _attribute_with_cross_tensor_feature_masks( + self, + formatted_inputs: Tuple[Tensor, ...], + formatted_additional_forward_args: Optional[Tuple[object, ...]], + target: TargetType, + baselines: BaselineType, + formatted_feature_mask: Tuple[Tensor, ...], + attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], + flattened_initial_eval: Tensor, + n_outputs: int, + total_attrib: List[Tensor], + weights: List[Tensor], + attrib_type: dtype, + **kwargs: Any, + ) -> Tuple[List[Tensor], List[Tensor]]: + for ( + current_inputs, + current_mask, + ) in self._ablation_generator( + formatted_inputs, + baselines, + formatted_feature_mask, + **kwargs, + ): + # modified_eval has (n_feature_perturbed * n_outputs) elements + # shape: + # agg mode: (*initial_eval.shape) + # non-agg mode: + # (feature_perturbed * batch_size, *initial_eval.shape[1:]) + modified_eval = _run_forward( + self.forward_func, + current_inputs, + target, + formatted_additional_forward_args, + ) - # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <: - # [Tensor, typing.Tuple[Tensor, ...]]]` - # but got `Union[Tensor, typing.Tuple[Tensor, ...]]`. - return self._generate_result(total_attrib, weights, is_inputs_tuple) # type: ignore # noqa: E501 line too long + if attr_progress is not None: + attr_progress.update() + + assert not isinstance(modified_eval, torch.Future), ( + "when use_futures is True, modified_eval should have " + f"non-Future type rather than {type(modified_eval)}" + ) + + total_attrib, weights = self._process_ablated_out_full( + modified_eval, + current_mask, + flattened_initial_eval, + formatted_inputs, + n_outputs, + total_attrib, + weights, + attrib_type, + ) + return total_attrib, weights + + def _ablation_generator( + self, + inputs: Tuple[Tensor, ...], + baselines: BaselineType, + input_mask: Tuple[Tensor, ...], + **kwargs: Any, + ) -> Generator[ + Tuple[ + Tuple[Tensor, ...], + Tuple[Tensor, ...], + ], + None, + None, + ]: + unique_feature_ids = torch.unique( + torch.cat([mask.flatten() for mask in input_mask]) + ).tolist() + + if isinstance(baselines, torch.Tensor): + baselines = baselines.reshape((1,) + tuple(baselines.shape)) + + # Process one feature per time, rather than processing every input tensor + for feature_idx in unique_feature_ids: + ablated_inputs, current_masks = ( + self._construct_ablated_input_across_tensors( + inputs, input_mask, baselines, feature_idx + ) + ) + yield ablated_inputs, current_masks + + def _construct_ablated_input_across_tensors( + self, + inputs: Tuple[Tensor, ...], + input_mask: Tuple[Tensor, ...], + baselines: BaselineType, + feature_idx: int, + ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: + + ablated_inputs = [] + current_masks = [] + for i, input_tensor in enumerate(inputs): + mask = input_mask[i] + tensor_mask = mask == feature_idx + if not tensor_mask.any(): + ablated_inputs.append(input_tensor) + current_masks.append(torch.zeros_like(tensor_mask)) + continue + tensor_mask = tensor_mask.to(input_tensor.device).long() + baseline = baselines[i] if isinstance(baselines, tuple) else baselines + if isinstance(baseline, torch.Tensor): + baseline = baseline.reshape( + (1,) * (input_tensor.dim() - baseline.dim()) + tuple(baseline.shape) + ) + assert baseline is not None, "baseline must be provided" + ablated_input = ( + input_tensor * (1 - tensor_mask).to(input_tensor.dtype) + ) + (baseline * tensor_mask.to(input_tensor.dtype)) + ablated_inputs.append(ablated_input) + current_masks.append(tensor_mask) + return tuple(ablated_inputs), tuple(current_masks) def _initial_eval_to_processed_initial_eval_fut( self, initial_eval: Future[Tensor], formatted_inputs: Tuple[Tensor, ...] @@ -572,6 +757,7 @@ def _attribute_progress_setup( self, formatted_inputs: Tuple[Tensor, ...], feature_mask: Tuple[Tensor, ...], + enable_cross_tensor_attribution: bool, perturbations_per_eval: int, **kwargs: Any, ): @@ -579,9 +765,13 @@ def _attribute_progress_setup( formatted_inputs, feature_mask, **kwargs ) total_forwards = ( - sum(math.ceil(count / perturbations_per_eval) for count in feature_counts) - + 1 - ) # add 1 for the initial eval + int(sum(feature_counts)) + if enable_cross_tensor_attribution + else sum( + math.ceil(count / perturbations_per_eval) for count in feature_counts + ) + ) + total_forwards += 1 # add 1 for the initial eval attr_progress = progress( desc=f"{self.get_name()} attribution", total=total_forwards ) @@ -808,10 +998,7 @@ def _construct_ablated_input( current_mask = current_mask.to(expanded_input.device) assert baseline is not None, "baseline must be provided" ablated_tensor = ( - expanded_input - * (1 - current_mask).to(expanded_input.dtype) - # pyre-fixme[58]: `*` is not supported for operand types `Union[None, float, - # Tensor]` and `Tensor`. + expanded_input * (1 - current_mask).to(expanded_input.dtype) ) + (baseline * current_mask.to(expanded_input.dtype)) return ablated_tensor, current_mask @@ -983,6 +1170,42 @@ def _process_ablated_out( total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(dim=0) return total_attrib, weights + def _process_ablated_out_full( + self, + modified_eval: Tensor, + current_mask: Tuple[Tensor, ...], + flattened_initial_eval: Tensor, + inputs: TensorOrTupleOfTensorsGeneric, + n_outputs: int, + total_attrib: List[Tensor], + weights: List[Tensor], + attrib_type: dtype, + ) -> Tuple[List[Tensor], List[Tensor]]: + modified_eval = self._parse_forward_out(modified_eval) + + # reshape the leading dim for n_feature_perturbed + # flatten each feature's eval outputs into 1D of (n_outputs) + modified_eval = modified_eval.reshape(-1, n_outputs) + # eval_diff in shape (n_feature_perturbed, n_outputs) + eval_diff = flattened_initial_eval - modified_eval + eval_diff_shape = eval_diff.shape + + # append the shape of one input example + # to make it broadcastable to mask + + if self.use_weights: + for weight, mask in zip(weights, current_mask): + weight += mask.float().sum(dim=0) + for i, mask in enumerate(current_mask): + if inputs[i].numel() == 0: + continue + eval_diff = eval_diff.reshape( + eval_diff_shape + (inputs[i].dim() - 1) * (1,) + ) + eval_diff = eval_diff.to(total_attrib[i].device) + total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0) + return total_attrib, weights + def _fut_tuple_to_accumulate_fut_list( self, total_attrib: List[Tensor], diff --git a/setup.py b/setup.py index 5effc53d72..f6e4558d6c 100755 --- a/setup.py +++ b/setup.py @@ -69,7 +69,7 @@ def report(*args): + [ "black", "flake8", - "sphinx", + "sphinx<8.2.0", "sphinx-autodoc-typehints", "sphinxcontrib-katex", "mypy>=0.760", diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py index 3646bd6c58..c8f9802d6a 100644 --- a/tests/attr/test_feature_ablation.py +++ b/tests/attr/test_feature_ablation.py @@ -170,8 +170,8 @@ def test_multi_input_ablation_with_mask(self) -> None: inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]]) - mask2 = torch.tensor([[0, 1, 2]]) - mask3 = torch.tensor([[0, 1, 2], [0, 0, 0]]) + mask2 = torch.tensor([[3, 4, 2]]) + mask3 = torch.tensor([[5, 6, 7], [5, 5, 5]]) expected = ( [[492.0, 492.0, 492.0], [200.0, 200.0, 200.0]], [[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]], @@ -207,8 +207,8 @@ def test_multi_input_ablation_with_mask(self) -> None: perturbations_per_eval=(1, 2, 3), ) - def test_multi_input_ablation_with_mask_nt(self) -> None: - ablation_algo = NoiseTunnel(FeatureAblation(BasicModel_MultiLayer_MultiInput())) + def test_multi_input_ablation_with_mask_dupe_feature_idx(self) -> None: + ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput()) inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) @@ -220,6 +220,66 @@ def test_multi_input_ablation_with_mask_nt(self) -> None: [[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]], [[0.0, 400.0, 40.0], [60.0, 60.0, 60.0]], ) + expected_cross_tensor = ( + [[1092.0, 1092.0, 1092.0], [260.0, 600.0, 260.0]], + [[80.0, 1092.0, 160.0], [260.0, 600.0, 0.0]], + [[80.0, 1092.0, 160.0], [260.0, 260.0, 260.0]], + ) + for test_enable_cross_tensor_attribution, expected_out in [ + (True, expected_cross_tensor), + (False, expected), + ]: + self._ablation_test_assert( + ablation_algo, + (inp1, inp2, inp3), + expected_out, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), + test_enable_cross_tensor_attribution=[ + test_enable_cross_tensor_attribution + ], + ) + + expected_with_baseline = ( + [[468.0, 468.0, 468.0], [184.0, 192.0, 184.0]], + [[68.0, 188.0, 108.0], [-12.0, 388.0, -12.0]], + [[-16.0, 384.0, 24.0], [12.0, 12.0, 12.0]], + ) + expected_cross_tensor_with_baseline = ( + [[1040.0, 1040.0, 1040.0], [184.0, 580.0, 184.0]], + [[52.0, 1040.0, 132.0], [184.0, 580.0, -12.0]], + [[52.0, 1040.0, 132.0], [184.0, 184.0, 184.0]], + ) + for test_enable_cross_tensor_attribution, expected_out in [ + (True, expected_cross_tensor_with_baseline), + (False, expected_with_baseline), + ]: + self._ablation_test_assert( + ablation_algo, + (inp1, inp2, inp3), + expected_out, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), + baselines=(2, 3.0, 4), + perturbations_per_eval=(1, 2, 3), + test_enable_cross_tensor_attribution=[ + test_enable_cross_tensor_attribution + ], + ) + + def test_multi_input_ablation_with_mask_nt(self) -> None: + ablation_algo = NoiseTunnel(FeatureAblation(BasicModel_MultiLayer_MultiInput())) + inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) + inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) + inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) + mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]]) + mask2 = torch.tensor([[3, 4, 2]]) + mask3 = torch.tensor([[5, 6, 7], [5, 5, 5]]) + expected = ( + [[492.0, 492.0, 492.0], [200.0, 200.0, 200.0]], + [[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]], + [[0.0, 400.0, 40.0], [60.0, 60.0, 60.0]], + ) self._ablation_test_assert( ablation_algo, (inp1, inp2, inp3), @@ -694,8 +754,8 @@ def _multi_input_batch_scalar_ablation_assert( inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) mask1 = torch.tensor([[1, 1, 1]]) - mask2 = torch.tensor([[0, 1, 2]]) - mask3 = torch.tensor([[0, 1, 2]]) + mask2 = torch.tensor([[0, 3, 2]]) + mask3 = torch.tensor([[4, 5, 6]]) expected = ( torch.tensor([[1784, 1784, 1784]], dtype=dtype), torch.tensor([[160, 1200, 240]], dtype=dtype), @@ -737,47 +797,50 @@ def _ablation_test_assert( perturbations_per_eval: Tuple[int, ...] = (1,), baselines: BaselineType = None, target: TargetType = 0, + test_enable_cross_tensor_attribution: List[bool] = [True, False], test_future: bool = False, **kwargs: Any, ) -> None: - for batch_size in perturbations_per_eval: - self.assertTrue(ablation_algo.multiplies_by_inputs) - if isinstance(ablation_algo, FeatureAblation) and test_future: - attributions = ablation_algo.attribute_future( - test_input, - target=target, - feature_mask=feature_mask, - additional_forward_args=additional_input, - baselines=baselines, - perturbations_per_eval=batch_size, - **kwargs, - ).wait() - else: - attributions = ablation_algo.attribute( - test_input, - target=target, - feature_mask=feature_mask, - additional_forward_args=additional_input, - baselines=baselines, - perturbations_per_eval=batch_size, - **kwargs, - ) - if isinstance(expected_ablation, tuple): - for i in range(len(expected_ablation)): - expected = expected_ablation[i] - if not isinstance(expected, torch.Tensor): - expected = torch.tensor(expected) - - self.assertEqual(attributions[i].shape, expected.shape) - self.assertEqual(attributions[i].dtype, expected.dtype) - assertTensorAlmostEqual(self, attributions[i], expected) - else: - if not isinstance(expected_ablation, torch.Tensor): - expected_ablation = torch.tensor(expected_ablation) - - self.assertEqual(attributions.shape, expected_ablation.shape) - self.assertEqual(attributions.dtype, expected_ablation.dtype) - assertTensorAlmostEqual(self, attributions, expected_ablation) + for enable_cross_tensor_attribution in test_enable_cross_tensor_attribution: + for batch_size in perturbations_per_eval: + self.assertTrue(ablation_algo.multiplies_by_inputs) + if isinstance(ablation_algo, FeatureAblation) and test_future: + attributions = ablation_algo.attribute_future( + test_input, + target=target, + feature_mask=feature_mask, + additional_forward_args=additional_input, + baselines=baselines, + perturbations_per_eval=batch_size, + **kwargs, + ).wait() + else: + attributions = ablation_algo.attribute( + test_input, + target=target, + feature_mask=feature_mask, + additional_forward_args=additional_input, + baselines=baselines, + perturbations_per_eval=batch_size, + enable_cross_tensor_attribution=enable_cross_tensor_attribution, + **kwargs, + ) + if isinstance(expected_ablation, tuple): + for i in range(len(expected_ablation)): + expected = expected_ablation[i] + if not isinstance(expected, torch.Tensor): + expected = torch.tensor(expected) + + self.assertEqual(attributions[i].shape, expected.shape) + self.assertEqual(attributions[i].dtype, expected.dtype) + assertTensorAlmostEqual(self, attributions[i], expected) + else: + if not isinstance(expected_ablation, torch.Tensor): + expected_ablation = torch.tensor(expected_ablation) + + self.assertEqual(attributions.shape, expected_ablation.shape) + self.assertEqual(attributions.dtype, expected_ablation.dtype) + assertTensorAlmostEqual(self, attributions, expected_ablation) if __name__ == "__main__":