Skip to content

Commit 6cb3ef6

Browse files
yucu authored and facebook-github-bot committed
Basic test for LayerFeaturePermutation (#1272)
Summary: Pull Request resolved: #1272 Add minimum test for LayerFeaturePermutation Will add more when switching to use customized attribute method for LayerFeaturePermutation rather than inheriting. Reviewed By: vivekmig Differential Revision: D56051183 fbshipit-source-id: 92c97eefacd7f88d55f5afe373d3361e67822e8d
1 parent 8b9983d commit 6cb3ef6

File tree

2 files changed

+78
-38
lines changed

2 files changed

+78
-38
lines changed

tests/attr/layer/test_layer_ablation.py

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ def test_simple_ablation_with_mask(self) -> None:
2121
net = BasicModel_MultiLayer()
2222
inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)
2323
self._ablation_test_assert(
24-
net,
25-
net.linear0,
26-
inp,
27-
([280.0, 280.0, 120.0],),
24+
model=net,
25+
layer=net.linear0,
26+
test_input=inp,
27+
expected_ablation=([280.0, 280.0, 120.0],),
2828
layer_mask=torch.tensor([[0, 0, 1]]),
2929
perturbations_per_eval=(1, 2, 3),
3030
attribute_to_layer_input=True,
@@ -37,20 +37,20 @@ def test_multi_input_ablation(self) -> None:
3737
inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]])
3838
baseline = torch.tensor([[1.0, 2.0, 3.0]])
3939
self._ablation_test_assert(
40-
net,
41-
net.model.linear1,
42-
(inp1, inp2, inp3),
43-
[[168.0, 992.0, 148.0], [84.0, 632.0, 120.0]],
40+
model=net,
41+
layer=net.model.linear1,
42+
test_input=(inp1, inp2, inp3),
43+
expected_ablation=[[168.0, 992.0, 148.0], [84.0, 632.0, 120.0]],
4444
additional_input=(1,),
4545
baselines=baseline,
4646
perturbations_per_eval=(1, 2, 3),
4747
attribute_to_layer_input=True,
4848
)
4949
self._ablation_test_assert(
50-
net,
51-
net.model.linear0,
52-
(inp1, inp2, inp3),
53-
[[168.0, 992.0, 148.0], [84.0, 632.0, 120.0]],
50+
model=net,
51+
layer=net.model.linear0,
52+
test_input=(inp1, inp2, inp3),
53+
expected_ablation=[[168.0, 992.0, 148.0], [84.0, 632.0, 120.0]],
5454
additional_input=(1,),
5555
baselines=baseline,
5656
perturbations_per_eval=(1, 2, 3),
@@ -65,21 +65,21 @@ def test_multi_input_ablation_with_layer_mask(self) -> None:
6565
baseline = torch.tensor([[1.0, 2.0, 3.0]])
6666
layer_mask = torch.tensor([[0, 1, 0], [0, 1, 2]])
6767
self._ablation_test_assert(
68-
net,
69-
net.model.linear1,
70-
(inp1, inp2, inp3),
71-
[[316.0, 992.0, 316.0], [84.0, 632.0, 120.0]],
68+
model=net,
69+
layer=net.model.linear1,
70+
test_input=(inp1, inp2, inp3),
71+
expected_ablation=[[316.0, 992.0, 316.0], [84.0, 632.0, 120.0]],
7272
additional_input=(1,),
7373
baselines=baseline,
7474
perturbations_per_eval=(1, 2, 3),
7575
layer_mask=layer_mask,
7676
attribute_to_layer_input=True,
7777
)
7878
self._ablation_test_assert(
79-
net,
80-
net.model.linear0,
81-
(inp1, inp2, inp3),
82-
[[316.0, 992.0, 316.0], [84.0, 632.0, 120.0]],
79+
model=net,
80+
layer=net.model.linear0,
81+
test_input=(inp1, inp2, inp3),
82+
expected_ablation=[[316.0, 992.0, 316.0], [84.0, 632.0, 120.0]],
8383
additional_input=(1,),
8484
baselines=baseline,
8585
layer_mask=layer_mask,
@@ -91,28 +91,32 @@ def test_simple_multi_input_conv_intermediate(self) -> None:
9191
inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)
9292
inp2 = torch.ones((1, 1, 4, 4))
9393
self._ablation_test_assert(
94-
net,
95-
net.relu1,
96-
(inp, inp2),
97-
[[[[4.0, 13.0], [40.0, 49.0]], [[0, 0], [-15.0, -24.0]]]],
94+
model=net,
95+
layer=net.relu1,
96+
test_input=(inp, inp2),
97+
expected_ablation=[[[[4.0, 13.0], [40.0, 49.0]], [[0, 0], [-15.0, -24.0]]]],
9898
perturbations_per_eval=(1, 2, 4, 8, 12, 16),
9999
)
100100
self._ablation_test_assert(
101-
net,
102-
net.relu1,
103-
(inp, inp2),
104-
([[[4.0, 13.0], [40.0, 49.0]], [[0, 0], [-15.0, -24.0]]],),
101+
model=net,
102+
layer=net.relu1,
103+
test_input=(inp, inp2),
104+
expected_ablation=(
105+
[[[4.0, 13.0], [40.0, 49.0]], [[0, 0], [-15.0, -24.0]]],
106+
),
105107
baselines=torch.tensor(
106108
[[[-4.0, -13.0], [-2.0, -2.0]], [[0, 0], [0.0, 0.0]]]
107109
),
108110
perturbations_per_eval=(1, 2, 4, 8, 12, 16),
109111
attribute_to_layer_input=True,
110112
)
111113
self._ablation_test_assert(
112-
net,
113-
net.relu1,
114-
(inp, inp2),
115-
[[[[17.0, 17.0], [67.0, 67.0]], [[0, 0], [-39.0, -39.0]]]],
114+
model=net,
115+
layer=net.relu1,
116+
test_input=(inp, inp2),
117+
expected_ablation=[
118+
[[[17.0, 17.0], [67.0, 67.0]], [[0, 0], [-39.0, -39.0]]]
119+
],
116120
perturbations_per_eval=(1, 2, 4),
117121
layer_mask=torch.tensor([[[[0, 0], [1, 1]], [[2, 2], [3, 3]]]]),
118122
)
@@ -121,17 +125,20 @@ def test_simple_multi_output_ablation(self) -> None:
121125
net = BasicModel_MultiLayer(multi_input_module=True)
122126
inp = torch.tensor([[0.0, 6.0, 0.0]])
123127
self._ablation_test_assert(
124-
net, net.multi_relu, inp, ([[0.0, 7.0, 7.0, 7.0]], [[0.0, 7.0, 7.0, 7.0]])
128+
model=net,
129+
layer=net.multi_relu,
130+
test_input=inp,
131+
expected_ablation=([[0.0, 7.0, 7.0, 7.0]], [[0.0, 7.0, 7.0, 7.0]]),
125132
)
126133

127134
def test_simple_multi_output_input_ablation(self) -> None:
128135
net = BasicModel_MultiLayer(multi_input_module=True)
129136
inp = torch.tensor([[0.0, 6.0, 0.0]])
130137
self._ablation_test_assert(
131-
net,
132-
net.multi_relu,
133-
inp,
134-
([[0.0, 7.0, 7.0, 7.0]], [[0.0, 7.0, 7.0, 7.0]]),
138+
model=net,
139+
layer=net.multi_relu,
140+
test_input=inp,
141+
expected_ablation=([[0.0, 7.0, 7.0, 7.0]], [[0.0, 7.0, 7.0, 7.0]]),
135142
attribute_to_layer_input=True,
136143
)
137144

@@ -151,7 +158,7 @@ def _ablation_test_assert(
151158
for batch_size in perturbations_per_eval:
152159
ablation = LayerFeatureAblation(model, layer)
153160
attributions = ablation.attribute(
154-
test_input,
161+
inputs=test_input,
155162
target=target,
156163
layer_mask=layer_mask,
157164
additional_forward_args=additional_input,
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
import torch
4+
from captum.attr._core.layer.layer_feature_permutation import LayerFeaturePermutation
5+
from tests.helpers import BaseTest
6+
from tests.helpers.basic import assertTensorAlmostEqual
7+
from tests.helpers.basic_models import BasicModel_MultiLayer
8+
from torch import Tensor
9+
10+
11+
class TestLayerFeaturePermutation(BaseTest):
12+
def test_single_input(self) -> None:
13+
net = BasicModel_MultiLayer()
14+
feature_importance = LayerFeaturePermutation(
15+
forward_func=net,
16+
layer=net.linear0,
17+
)
18+
19+
batch_size = 2
20+
input_size = (3,)
21+
constant_value = 10000
22+
23+
inp = torch.randn((batch_size,) + input_size)
24+
inp[:, 0] = constant_value
25+
26+
attribs = feature_importance.attribute(inputs=inp)
27+
28+
self.assertTrue(isinstance(attribs, Tensor))
29+
self.assertEqual(len(attribs), 4)
30+
self.assertEqual(attribs.squeeze(0).size(), (2 * batch_size,) + input_size)
31+
zeros = torch.zeros(2 * batch_size)
32+
assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0, mode="max")
33+
self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all())

0 commit comments

Comments (0)