Skip to content

Commit 31fa27b

Browse files
sarahtranfb and facebook-github-bot
authored and committed
perturb FA
Differential Revision: D71435704
1 parent b899732 commit 31fa27b

File tree

2 files changed

+180
-36
lines changed

2 files changed

+180
-36
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 123 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -353,10 +353,12 @@ def attribute(
353353
formatted_feature_mask,
354354
attr_progress,
355355
flattened_initial_eval,
356+
initial_eval,
356357
n_outputs,
357358
total_attrib,
358359
weights,
359360
attrib_type,
361+
perturbations_per_eval,
360362
**kwargs,
361363
)
362364
else:
@@ -470,10 +472,12 @@ def _attribute_with_cross_tensor_feature_masks(
470472
formatted_feature_mask: Tuple[Tensor, ...],
471473
attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
472474
flattened_initial_eval: Tensor,
475+
initial_eval: Tensor,
473476
n_outputs: int,
474477
total_attrib: List[Tensor],
475478
weights: List[Tensor],
476479
attrib_type: dtype,
480+
perturbations_per_eval: int,
477481
**kwargs: Any,
478482
) -> Tuple[List[Tensor], List[Tensor]]:
479483
feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
@@ -482,17 +486,66 @@ def _attribute_with_cross_tensor_feature_masks(
482486
if feature_idx.item() not in feature_idx_to_tensor_idx:
483487
feature_idx_to_tensor_idx[feature_idx.item()] = []
484488
feature_idx_to_tensor_idx[feature_idx.item()].append(i)
489+
all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
490+
additional_args_repeated: object
491+
if perturbations_per_eval > 1:
492+
additional_args_repeated = (
493+
_expand_additional_forward_args(
494+
formatted_additional_forward_args, perturbations_per_eval
495+
)
496+
if formatted_additional_forward_args is not None
497+
else None
498+
)
499+
target_repeated = _expand_target(target, perturbations_per_eval)
500+
else:
501+
additional_args_repeated = formatted_additional_forward_args
502+
target_repeated = target
503+
num_examples = formatted_inputs[0].shape[0]
504+
505+
current_additional_args: object
506+
# Process one feature per time, rather than processing every input tensor
507+
for i in range(0, len(all_feature_idxs), perturbations_per_eval):
508+
current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
509+
current_num_ablated_features = min(
510+
perturbations_per_eval, len(current_feature_idxs)
511+
)
512+
513+
# Store appropriate inputs and additional args based on batch size.
514+
if current_num_ablated_features != perturbations_per_eval:
515+
current_additional_args = (
516+
_expand_additional_forward_args(
517+
formatted_additional_forward_args, current_num_ablated_features
518+
)
519+
if formatted_additional_forward_args is not None
520+
else None
521+
)
522+
current_target = _expand_target(target, current_num_ablated_features)
523+
else:
524+
current_additional_args = additional_args_repeated
525+
current_target = target_repeated
526+
527+
current_inputs = ()
528+
current_masks = []
529+
for (
530+
single_perturb_input,
531+
single_perturb_masks,
532+
) in self._ablation_generator(
533+
formatted_inputs,
534+
baselines,
535+
formatted_feature_mask,
536+
current_feature_idxs,
537+
feature_idx_to_tensor_idx,
538+
**kwargs,
539+
):
540+
if len(current_inputs) == 0:
541+
current_inputs = single_perturb_input
542+
else:
543+
current_inputs = tuple(
544+
torch.cat((current_inputs[j], single_perturb_input[j]), dim=0)
545+
for j in range(len(current_inputs))
546+
)
547+
current_masks.append(list(single_perturb_masks))
485548

486-
for (
487-
current_inputs,
488-
current_mask,
489-
) in self._ablation_generator(
490-
formatted_inputs,
491-
baselines,
492-
formatted_feature_mask,
493-
feature_idx_to_tensor_idx,
494-
**kwargs,
495-
):
496549
# modified_eval has (n_feature_perturbed * n_outputs) elements
497550
# shape:
498551
# agg mode: (*initial_eval.shape)
@@ -501,8 +554,8 @@ def _attribute_with_cross_tensor_feature_masks(
501554
modified_eval = _run_forward(
502555
self.forward_func,
503556
current_inputs,
504-
target,
505-
formatted_additional_forward_args,
557+
current_target,
558+
current_additional_args,
506559
)
507560

508561
if attr_progress is not None:
@@ -515,13 +568,16 @@ def _attribute_with_cross_tensor_feature_masks(
515568

516569
total_attrib, weights = self._process_ablated_out_full(
517570
modified_eval,
518-
current_mask,
571+
current_masks,
519572
flattened_initial_eval,
520-
formatted_inputs,
573+
initial_eval,
574+
current_inputs,
521575
n_outputs,
576+
num_examples,
522577
total_attrib,
523578
weights,
524579
attrib_type,
580+
perturbations_per_eval,
525581
)
526582
return total_attrib, weights
527583

@@ -530,6 +586,7 @@ def _ablation_generator(
530586
inputs: Tuple[Tensor, ...],
531587
baselines: BaselineType,
532588
input_mask: Tuple[Tensor, ...],
589+
feature_idxs: List[int],
533590
feature_idx_to_tensor_idx: Dict[int, List[int]],
534591
**kwargs: Any,
535592
) -> Generator[
@@ -540,11 +597,8 @@ def _ablation_generator(
540597
None,
541598
None,
542599
]:
543-
if isinstance(baselines, torch.Tensor):
544-
baselines = baselines.reshape((1,) + tuple(baselines.shape))
545-
546600
# Process one feature per time, rather than processing every input tensor
547-
for feature_idx in feature_idx_to_tensor_idx.keys():
601+
for feature_idx in feature_idxs:
548602
ablated_inputs, current_masks = (
549603
self._construct_ablated_input_across_tensors(
550604
inputs,
@@ -784,7 +838,7 @@ def _attribute_progress_setup(
784838
formatted_inputs, feature_mask, **kwargs
785839
)
786840
total_forwards = (
787-
int(sum(feature_counts))
841+
math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
788842
if enable_cross_tensor_attribution
789843
else sum(
790844
math.ceil(count / perturbations_per_eval) for count in feature_counts
@@ -1187,43 +1241,76 @@ def _process_ablated_out(
11871241
weights[i] += current_mask.float().sum(dim=0)
11881242

11891243
total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(dim=0)
1244+
print(i, weights)
11901245
return total_attrib, weights
11911246

11921247
def _process_ablated_out_full(
11931248
self,
11941249
modified_eval: Tensor,
1195-
current_mask: Tuple[Optional[Tensor], ...],
1250+
current_mask: List[List[Optional[Tensor]]],
11961251
flattened_initial_eval: Tensor,
1252+
initial_eval: Tensor,
11971253
inputs: TensorOrTupleOfTensorsGeneric,
11981254
n_outputs: int,
1255+
num_examples: int,
11991256
total_attrib: List[Tensor],
12001257
weights: List[Tensor],
12011258
attrib_type: dtype,
1259+
perturbations_per_eval: int,
12021260
) -> Tuple[List[Tensor], List[Tensor]]:
12031261
modified_eval = self._parse_forward_out(modified_eval)
1262+
# if perturbations_per_eval > 1, the output shape must grow with
1263+
# input and not be aggregated
1264+
current_batch_size = inputs[0].shape[0]
1265+
1266+
# number of perturbation, which is not the same as
1267+
# perturbations_per_eval when not enough features to perturb
1268+
n_perturb = current_batch_size / num_examples
1269+
if perturbations_per_eval > 1 and not self._is_output_shape_valid:
1270+
1271+
current_output_shape = modified_eval.shape
1272+
1273+
# use initial_eval as the forward of perturbations_per_eval = 1
1274+
initial_output_shape = initial_eval.shape
1275+
1276+
assert (
1277+
# check if the output is not a scalar
1278+
current_output_shape
1279+
and initial_output_shape
1280+
# check if the output grow in same ratio, i.e., not agg
1281+
and current_output_shape[0] == n_perturb * initial_output_shape[0]
1282+
), (
1283+
"When perturbations_per_eval > 1, forward_func's output "
1284+
"should be a tensor whose 1st dim grow with the input "
1285+
f"batch size: when input batch size is {num_examples}, "
1286+
f"the output shape is {initial_output_shape}; "
1287+
f"when input batch size is {current_batch_size}, "
1288+
f"the output shape is {current_output_shape}"
1289+
)
1290+
1291+
self._is_output_shape_valid = True
12041292

12051293
# reshape the leading dim for n_feature_perturbed
12061294
# flatten each feature's eval outputs into 1D of (n_outputs)
12071295
modified_eval = modified_eval.reshape(-1, n_outputs)
12081296
# eval_diff in shape (n_feature_perturbed, n_outputs)
12091297
eval_diff = flattened_initial_eval - modified_eval
1210-
eval_diff_shape = eval_diff.shape
1211-
1212-
# append the shape of one input example
1213-
# to make it broadcastable to mask
12141298

1215-
if self.use_weights:
1216-
for weight, mask in zip(weights, current_mask):
1217-
if mask is not None:
1218-
weight += mask.float().sum(dim=0)
1219-
for i, mask in enumerate(current_mask):
1220-
if mask is None or inputs[i].numel() == 0:
1221-
continue
1222-
eval_diff = eval_diff.reshape(
1223-
eval_diff_shape + (inputs[i].dim() - 1) * (1,)
1224-
)
1225-
eval_diff = eval_diff.to(total_attrib[i].device)
1226-
total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)
1299+
for j in range(int(n_perturb)):
1300+
single_perturb_mask = current_mask[j]
1301+
if self.use_weights:
1302+
for weight, mask in zip(weights, single_perturb_mask):
1303+
if mask is not None:
1304+
weight += mask.float()
1305+
for i, mask in enumerate(single_perturb_mask):
1306+
this_input = inputs[i][j * num_examples : (j + 1) * num_examples]
1307+
if mask is None or this_input.numel() == 0:
1308+
continue
1309+
eval_diff_j = eval_diff[j].reshape(
1310+
eval_diff[j].shape + (this_input.dim() - 1) * (1,)
1311+
)
1312+
eval_diff_j = eval_diff_j.to(total_attrib[i].device)
1313+
total_attrib[i] += eval_diff_j * mask.to(attrib_type)
12271314
return total_attrib, weights
12281315

12291316
def _fut_tuple_to_accumulate_fut_list(

tests/attr/test_feature_ablation.py

Lines changed: 57 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -164,6 +164,19 @@ def test_multi_sample_ablation_with_mask(self) -> None:
164164
perturbations_per_eval=(1, 2, 3),
165165
)
166166

167+
def test_multi_sample_ablation_with_mask_weighted(self) -> None:
168+
ablation_algo = FeatureAblation(BasicModel_MultiLayer())
169+
ablation_algo.use_weights = True
170+
inp = torch.tensor([[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]], requires_grad=True)
171+
mask = torch.tensor([[0, 0, 1], [1, 1, 0]])
172+
self._ablation_test_assert(
173+
ablation_algo,
174+
inp,
175+
[[41.0, 41.0, 12.0], [280.0, 280.0, 120.0]],
176+
feature_mask=mask,
177+
perturbations_per_eval=(1, 2, 3),
178+
)
179+
167180
def test_multi_input_ablation_with_mask(self) -> None:
168181
ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput())
169182
inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]])
@@ -207,6 +220,50 @@ def test_multi_input_ablation_with_mask(self) -> None:
207220
perturbations_per_eval=(1, 2, 3),
208221
)
209222

223+
def test_multi_input_ablation_with_mask_weighted(self) -> None:
224+
ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput())
225+
ablation_algo.use_weights = True
226+
inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]])
227+
inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]])
228+
inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]])
229+
mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]])
230+
mask2 = torch.tensor([[3, 4, 2]])
231+
mask3 = torch.tensor([[5, 6, 7], [5, 5, 5]])
232+
expected = (
233+
[[492.0, 492.0, 492.0], [200.0, 200.0, 200.0]],
234+
[[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]],
235+
[[0.0, 400.0, 40.0], [60.0, 60.0, 60.0]],
236+
)
237+
self._ablation_test_assert(
238+
ablation_algo,
239+
(inp1, inp2, inp3),
240+
expected,
241+
additional_input=(1,),
242+
feature_mask=(mask1, mask2, mask3),
243+
)
244+
self._ablation_test_assert(
245+
ablation_algo,
246+
(inp1, inp2),
247+
expected[0:1],
248+
additional_input=(inp3, 1),
249+
feature_mask=(mask1, mask2),
250+
perturbations_per_eval=(1, 2, 3),
251+
)
252+
expected_with_baseline = (
253+
[[468.0, 468.0, 468.0], [184.0, 192.0, 184.0]],
254+
[[68.0, 188.0, 108.0], [-12.0, 388.0, -12.0]],
255+
[[-16.0, 384.0, 24.0], [12.0, 12.0, 12.0]],
256+
)
257+
self._ablation_test_assert(
258+
ablation_algo,
259+
(inp1, inp2, inp3),
260+
expected_with_baseline,
261+
additional_input=(1,),
262+
feature_mask=(mask1, mask2, mask3),
263+
baselines=(2, 3.0, 4),
264+
perturbations_per_eval=(1, 2, 3),
265+
)
266+
210267
def test_multi_input_ablation_with_mask_dupe_feature_idx(self) -> None:
211268
ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput())
212269
inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]])

0 commit comments

Comments
 (0)