From cac530441d153c454edde1086c28a58dca72b508 Mon Sep 17 00:00:00 2001 From: Sarah Tran Date: Tue, 25 Mar 2025 10:58:19 -0700 Subject: [PATCH] Support multiple perturbations per eval when masking across tensors (#1530) Summary: This was supported in the old path (when constructing ablated inputs over each input tensor individually) to improve compute efficiency by optionally passing in multiple perturbed inputs to the model fwd function. Reviewed By: craymichael Differential Revision: D71435704 --- captum/attr/_core/feature_ablation.py | 208 ++++++++++++++++------- captum/attr/_core/feature_permutation.py | 46 +++-- tests/attr/test_feature_ablation.py | 57 +++++++ 3 files changed, 240 insertions(+), 71 deletions(-) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index d7f2570c9b..c6a47417e4 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -353,10 +353,12 @@ def attribute( formatted_feature_mask, attr_progress, flattened_initial_eval, + initial_eval, n_outputs, total_attrib, weights, attrib_type, + perturbations_per_eval, **kwargs, ) else: @@ -470,10 +472,12 @@ def _attribute_with_cross_tensor_feature_masks( formatted_feature_mask: Tuple[Tensor, ...], attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], flattened_initial_eval: Tensor, + initial_eval: Tensor, n_outputs: int, total_attrib: List[Tensor], weights: List[Tensor], attrib_type: dtype, + perturbations_per_eval: int, **kwargs: Any, ) -> Tuple[List[Tensor], List[Tensor]]: feature_idx_to_tensor_idx: Dict[int, List[int]] = {} @@ -482,17 +486,78 @@ def _attribute_with_cross_tensor_feature_masks( if feature_idx.item() not in feature_idx_to_tensor_idx: feature_idx_to_tensor_idx[feature_idx.item()] = [] feature_idx_to_tensor_idx[feature_idx.item()].append(i) + all_feature_idxs = list(feature_idx_to_tensor_idx.keys()) + + additional_args_repeated: object + if perturbations_per_eval > 1: + # Repeat features and additional args for batch size. + all_features_repeated = tuple( + torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0) + for j in range(len(formatted_inputs)) + ) + additional_args_repeated = ( + _expand_additional_forward_args( + formatted_additional_forward_args, perturbations_per_eval + ) + if formatted_additional_forward_args is not None + else None + ) + target_repeated = _expand_target(target, perturbations_per_eval) + else: + all_features_repeated = formatted_inputs + additional_args_repeated = formatted_additional_forward_args + target_repeated = target + num_examples = formatted_inputs[0].shape[0] + + current_additional_args: object + if isinstance(baselines, tuple): + reshaped = False + reshaped_baselines: list[Union[Tensor, int, float]] = [] + for baseline in baselines: + if isinstance(baseline, Tensor): + reshaped = True + reshaped_baselines.append( + baseline.reshape((1,) + tuple(baseline.shape)) + ) + else: + reshaped_baselines.append(baseline) + baselines = tuple(reshaped_baselines) if reshaped else baselines + for i in range(0, len(all_feature_idxs), perturbations_per_eval): + current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval] + current_num_ablated_features = min( + perturbations_per_eval, len(current_feature_idxs) + ) + + # Store appropriate inputs and additional args based on batch size. + if current_num_ablated_features != perturbations_per_eval: + current_additional_args = ( + _expand_additional_forward_args( + formatted_additional_forward_args, current_num_ablated_features + ) + if formatted_additional_forward_args is not None + else None + ) + current_target = _expand_target(target, current_num_ablated_features) + expanded_inputs = tuple( + feature_repeated[0 : current_num_ablated_features * num_examples] + for feature_repeated in all_features_repeated + ) + else: + current_additional_args = additional_args_repeated + current_target = target_repeated + expanded_inputs = all_features_repeated + + current_inputs, current_masks = ( + self._construct_ablated_input_across_tensors( + expanded_inputs, + formatted_feature_mask, + baselines, + current_feature_idxs, + feature_idx_to_tensor_idx, + current_num_ablated_features, + ) + ) - for ( - current_inputs, - current_mask, - ) in self._ablation_generator( - formatted_inputs, - baselines, - formatted_feature_mask, - feature_idx_to_tensor_idx, - **kwargs, - ): # modified_eval has (n_feature_perturbed * n_outputs) elements # shape: # agg mode: (*initial_eval.shape) @@ -501,8 +566,8 @@ def _attribute_with_cross_tensor_feature_masks( modified_eval = _run_forward( self.forward_func, current_inputs, - target, - formatted_additional_forward_args, + current_target, + current_additional_args, ) if attr_progress is not None: @@ -515,75 +580,65 @@ def _attribute_with_cross_tensor_feature_masks( total_attrib, weights = self._process_ablated_out_full( modified_eval, - current_mask, + current_masks, flattened_initial_eval, - formatted_inputs, + initial_eval, + current_inputs, n_outputs, + num_examples, total_attrib, weights, attrib_type, + perturbations_per_eval, ) return total_attrib, weights - def _ablation_generator( - self, - inputs: Tuple[Tensor, ...], - baselines: BaselineType, - input_mask: Tuple[Tensor, ...], - feature_idx_to_tensor_idx: Dict[int, List[int]], - **kwargs: Any, - ) -> Generator[ - Tuple[ - Tuple[Tensor, ...], - Tuple[Optional[Tensor], ...], - ], - None, - None, - ]: - if isinstance(baselines, torch.Tensor): - baselines = baselines.reshape((1,) + tuple(baselines.shape)) - - # Process one feature per time, rather than processing every input tensor - for feature_idx in feature_idx_to_tensor_idx.keys(): - ablated_inputs, current_masks = ( - self._construct_ablated_input_across_tensors( - inputs, - input_mask, - baselines, - feature_idx, - feature_idx_to_tensor_idx[feature_idx], - ) - ) - yield ablated_inputs, current_masks - def _construct_ablated_input_across_tensors( self, inputs: Tuple[Tensor, ...], input_mask: Tuple[Tensor, ...], baselines: BaselineType, - feature_idx: int, - tensor_idxs: List[int], + feature_idxs: List[int], + feature_idx_to_tensor_idx: Dict[int, List[int]], + current_num_ablated_features: int, ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]: - ablated_inputs = [] current_masks: List[Optional[Tensor]] = [] + tensor_idxs = { + tensor_idx + for sublist in ( + feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs + ) + for tensor_idx in sublist + } + for i, input_tensor in enumerate(inputs): if i not in tensor_idxs: ablated_inputs.append(input_tensor) current_masks.append(None) continue - tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long() + tensor_mask = [] + ablated_input = input_tensor.clone() baseline = baselines[i] if isinstance(baselines, tuple) else baselines - if isinstance(baseline, torch.Tensor): - baseline = baseline.reshape( - (1,) * (input_tensor.dim() - baseline.dim()) + tuple(baseline.shape) + for j, feature_idx in enumerate(feature_idxs): + original_input_size = ( + input_tensor.shape[0] // current_num_ablated_features ) - assert baseline is not None, "baseline must be provided" - ablated_input = ( - input_tensor * (1 - tensor_mask).to(input_tensor.dtype) - ) + (baseline * tensor_mask.to(input_tensor.dtype)) + start_idx = j * original_input_size + end_idx = (j + 1) * original_input_size + + mask = (input_mask[i] == feature_idx).to(input_tensor.device).long() + if mask.ndim == 0: + mask = mask.reshape((1,) * input_tensor.dim()) + tensor_mask.append(mask) + + assert baseline is not None, "baseline must be provided" + ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * ( + 1 - mask + ) + (baseline * mask.to(input_tensor.dtype)) + current_masks.append(torch.stack(tensor_mask, dim=0)) ablated_inputs.append(ablated_input) - current_masks.append(tensor_mask) + return tuple(ablated_inputs), tuple(current_masks) def _initial_eval_to_processed_initial_eval_fut( @@ -784,7 +839,7 @@ def _attribute_progress_setup( formatted_inputs, feature_mask, **kwargs ) total_forwards = ( - int(sum(feature_counts)) + math.ceil(int(sum(feature_counts)) / perturbations_per_eval) if enable_cross_tensor_attribution else sum( math.ceil(count / perturbations_per_eval) for count in feature_counts @@ -1194,13 +1249,46 @@ def _process_ablated_out_full( modified_eval: Tensor, current_mask: Tuple[Optional[Tensor], ...], flattened_initial_eval: Tensor, + initial_eval: Tensor, inputs: TensorOrTupleOfTensorsGeneric, n_outputs: int, + num_examples: int, total_attrib: List[Tensor], weights: List[Tensor], attrib_type: dtype, + perturbations_per_eval: int, ) -> Tuple[List[Tensor], List[Tensor]]: modified_eval = self._parse_forward_out(modified_eval) + # if perturbations_per_eval > 1, the output shape must grow with + # input and not be aggregated + current_batch_size = inputs[0].shape[0] + + # number of perturbation, which is not the same as + # perturbations_per_eval when not enough features to perturb + n_perturb = current_batch_size / num_examples + if perturbations_per_eval > 1 and not self._is_output_shape_valid: + + current_output_shape = modified_eval.shape + + # use initial_eval as the forward of perturbations_per_eval = 1 + initial_output_shape = initial_eval.shape + + assert ( + # check if the output is not a scalar + current_output_shape + and initial_output_shape + # check if the output grow in same ratio, i.e., not agg + and current_output_shape[0] == n_perturb * initial_output_shape[0] + ), ( + "When perturbations_per_eval > 1, forward_func's output " + "should be a tensor whose 1st dim grow with the input " + f"batch size: when input batch size is {num_examples}, " + f"the output shape is {initial_output_shape}; " + f"when input batch size is {current_batch_size}, " + f"the output shape is {current_output_shape}" + ) + + self._is_output_shape_valid = True # reshape the leading dim for n_feature_perturbed # flatten each feature's eval outputs into 1D of (n_outputs) @@ -1209,9 +1297,6 @@ def _process_ablated_out_full( eval_diff = flattened_initial_eval - modified_eval eval_diff_shape = eval_diff.shape - # append the shape of one input example - # to make it broadcastable to mask - if self.use_weights: for weight, mask in zip(weights, current_mask): if mask is not None: @@ -1224,6 +1309,7 @@ def _process_ablated_out_full( ) eval_diff = eval_diff.to(total_attrib[i].device) total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0) + return total_attrib, weights def _fut_tuple_to_accumulate_fut_list( diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py index 3657c00fc2..0d64f1d8b0 100644 --- a/captum/attr/_core/feature_permutation.py +++ b/captum/attr/_core/feature_permutation.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # pyre-strict -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric @@ -391,15 +391,41 @@ def _construct_ablated_input_across_tensors( inputs: Tuple[Tensor, ...], input_mask: Tuple[Tensor, ...], baselines: BaselineType, - feature_idx: int, - tensor_idxs: List[int], + feature_idxs: List[int], + feature_idx_to_tensor_idx: Dict[int, List[int]], + current_num_ablated_features: int, ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]: current_masks: List[Optional[Tensor]] = [] - for i, mask in enumerate(input_mask): - if i in tensor_idxs: - current_masks.append((mask == feature_idx).to(inputs[0].device)) - else: + tensor_idxs = { + tensor_idx + for sublist in ( + feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs + ) + for tensor_idx in sublist + } + permuted_inputs = [] + for i, input_tensor in enumerate(inputs): + if i not in tensor_idxs: current_masks.append(None) - feature_masks = tuple(current_masks) - permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks) - return permuted_outputs, feature_masks + permuted_inputs.append(input_tensor) + continue + tensor_mask = [] + permuted_input = input_tensor.clone() + for j, feature_idx in enumerate(feature_idxs): + original_input_size = ( + input_tensor.shape[0] // current_num_ablated_features + ) + start_idx = j * original_input_size + end_idx = (j + 1) * original_input_size + + mask = (input_mask[i] == feature_idx).to(input_tensor.device).bool() + if mask.ndim == 0: + mask = mask.reshape((1,) * input_tensor.dim()) + tensor_mask.append(mask) + permuted_input[start_idx:end_idx] = self.perm_func( + input_tensor[start_idx:end_idx], mask + ) + current_masks.append(torch.stack(tensor_mask, dim=0)) + permuted_inputs.append(permuted_input) + + return tuple(permuted_inputs), tuple(current_masks) diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py index c8f9802d6a..5c3101ad01 100644 --- a/tests/attr/test_feature_ablation.py +++ b/tests/attr/test_feature_ablation.py @@ -164,6 +164,19 @@ def test_multi_sample_ablation_with_mask(self) -> None: perturbations_per_eval=(1, 2, 3), ) + def test_multi_sample_ablation_with_mask_weighted(self) -> None: + ablation_algo = FeatureAblation(BasicModel_MultiLayer()) + ablation_algo.use_weights = True + inp = torch.tensor([[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]]) + mask = torch.tensor([[0, 0, 1], [1, 1, 0]]) + self._ablation_test_assert( + ablation_algo, + inp, + [[41.0, 41.0, 12.0], [280.0, 280.0, 120.0]], + feature_mask=mask, + perturbations_per_eval=(1, 2, 3), + ) + def test_multi_input_ablation_with_mask(self) -> None: ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput()) inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) @@ -207,6 +220,50 @@ def test_multi_input_ablation_with_mask(self) -> None: perturbations_per_eval=(1, 2, 3), ) + def test_multi_input_ablation_with_mask_weighted(self) -> None: + ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput()) + ablation_algo.use_weights = True + inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) + inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) + inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) + mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]]) + mask2 = torch.tensor([[3, 4, 2]]) + mask3 = torch.tensor([[5, 6, 7], [5, 5, 5]]) + expected = ( + [[492.0, 492.0, 492.0], [200.0, 200.0, 200.0]], + [[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]], + [[0.0, 400.0, 40.0], [60.0, 60.0, 60.0]], + ) + self._ablation_test_assert( + ablation_algo, + (inp1, inp2, inp3), + expected, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), + ) + self._ablation_test_assert( + ablation_algo, + (inp1, inp2), + expected[0:1], + additional_input=(inp3, 1), + feature_mask=(mask1, mask2), + perturbations_per_eval=(1, 2, 3), + ) + expected_with_baseline = ( + [[468.0, 468.0, 468.0], [184.0, 192.0, 184.0]], + [[68.0, 188.0, 108.0], [-12.0, 388.0, -12.0]], + [[-16.0, 384.0, 24.0], [12.0, 12.0, 12.0]], + ) + self._ablation_test_assert( + ablation_algo, + (inp1, inp2, inp3), + expected_with_baseline, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), + baselines=(2, 3.0, 4), + perturbations_per_eval=(1, 2, 3), + ) + def test_multi_input_ablation_with_mask_dupe_feature_idx(self) -> None: ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput()) inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]])