Skip to content

Commit 8269ba5

Browse files
sarahtranfb authored and facebook-github-bot committed
Adjust indices in LayerAttributor mask for individual neurons (#1531)
Summary: With `enable_cross_tensor_attribution=True` for `FeatureAblation`/`FeaturePermutation`, ids/indices in the masks are now "global". Reviewed By: cyrjano. Differential Revision: D71778355
1 parent 8723706 commit 8269ba5

File tree

3 files changed

+15
-4
lines changed

3 files changed

+15
-4
lines changed

captum/_utils/common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,17 +194,18 @@ def _is_mask_valid(mask: Tensor, inp: Tensor) -> bool:
194194
def _format_feature_mask(
195195
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]],
196196
inputs: Tuple[Tensor, ...],
197+
start_idx: int = 0,
197198
) -> Tuple[Tensor, ...]:
198199
"""
199200
Format a feature mask into a tuple of tensors.
200201
The `inputs` should be correctly formatted first
201202
If `feature_mask` is None, assign each non-batch dimension with a consecutive
202-
integer from 0.
203+
integer from `start_idx`.
203204
If `feature_mask` is a tensor, wrap it in a tuple.
204205
"""
205206
if feature_mask is None:
206207
formatted_mask = []
207-
current_num_features = 0
208+
current_num_features = start_idx
208209
for inp in inputs:
209210
# the following can handle empty tensor where numel is 0
210211
# empty tensor will be added to the feature mask

captum/testing/helpers/basic_models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,16 @@ def forward(self, x1: Tensor, x2: Tensor, x3: Tensor, scale: int):
662662
return self.model(scale * (x1 + x2 + x3))
663663

664664

665+
class BasicModel_MultiLayer_TupleInput(nn.Module):
666+
def __init__(self) -> None:
667+
super().__init__()
668+
self.model = BasicModel_MultiLayer()
669+
670+
@no_type_check
671+
def forward(self, x: Tuple[Tensor, Tensor, Tensor]) -> Tensor:
672+
return self.model(x[0] + x[1] + x[2])
673+
674+
665675
class BasicModel_MultiLayer_MultiInput_with_Future(nn.Module):
666676
def __init__(self) -> None:
667677
super().__init__()

tests/attr/neuron/test_neuron_ablation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def test_multi_input_ablation_with_mask(self) -> None:
8383
inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]])
8484
inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]])
8585
mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]])
86-
mask2 = torch.tensor([[0, 1, 2]])
87-
mask3 = torch.tensor([[0, 1, 2], [0, 0, 0]])
86+
mask2 = torch.tensor([[3, 4, 2]])
87+
mask3 = torch.tensor([[5, 6, 7], [5, 5, 5]])
8888
expected = (
8989
[[492.0, 492.0, 492.0], [200.0, 200.0, 200.0]],
9090
[[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]],

0 commit comments

Comments
 (0)