tests/attr/layer/test_layer_ablation.py (83 changes: 45 additions & 38 deletions)
@@ -21,10 +21,10 @@ def test_simple_ablation_with_mask(self) -> None:
         net = BasicModel_MultiLayer()
         inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)
         self._ablation_test_assert(
-            net,
-            net.linear0,
-            inp,
-            ([280.0, 280.0, 120.0],),
+            model=net,
+            layer=net.linear0,
+            test_input=inp,
+            expected_ablation=([280.0, 280.0, 120.0],),
             layer_mask=torch.tensor([[0, 0, 1]]),
             perturbations_per_eval=(1, 2, 3),
             attribute_to_layer_input=True,
@@ -37,20 +37,20 @@ def test_multi_input_ablation(self) -> None:
         inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]])
         baseline = torch.tensor([[1.0, 2.0, 3.0]])
         self._ablation_test_assert(
-            net,
-            net.model.linear1,
-            (inp1, inp2, inp3),
-            [[168.0, 992.0, 148.0], [84.0, 632.0, 120.0]],
+            model=net,
+            layer=net.model.linear1,
+            test_input=(inp1, inp2, inp3),
+            expected_ablation=[[168.0, 992.0, 148.0], [84.0, 632.0, 120.0]],
             additional_input=(1,),
             baselines=baseline,
             perturbations_per_eval=(1, 2, 3),
             attribute_to_layer_input=True,
         )
         self._ablation_test_assert(
-            net,
-            net.model.linear0,
-            (inp1, inp2, inp3),
-            [[168.0, 992.0, 148.0], [84.0, 632.0, 120.0]],
+            model=net,
+            layer=net.model.linear0,
+            test_input=(inp1, inp2, inp3),
+            expected_ablation=[[168.0, 992.0, 148.0], [84.0, 632.0, 120.0]],
             additional_input=(1,),
             baselines=baseline,
             perturbations_per_eval=(1, 2, 3),
@@ -65,21 +65,21 @@ def test_multi_input_ablation_with_layer_mask(self) -> None:
         baseline = torch.tensor([[1.0, 2.0, 3.0]])
         layer_mask = torch.tensor([[0, 1, 0], [0, 1, 2]])
         self._ablation_test_assert(
-            net,
-            net.model.linear1,
-            (inp1, inp2, inp3),
-            [[316.0, 992.0, 316.0], [84.0, 632.0, 120.0]],
+            model=net,
+            layer=net.model.linear1,
+            test_input=(inp1, inp2, inp3),
+            expected_ablation=[[316.0, 992.0, 316.0], [84.0, 632.0, 120.0]],
             additional_input=(1,),
             baselines=baseline,
             perturbations_per_eval=(1, 2, 3),
             layer_mask=layer_mask,
             attribute_to_layer_input=True,
         )
         self._ablation_test_assert(
-            net,
-            net.model.linear0,
-            (inp1, inp2, inp3),
-            [[316.0, 992.0, 316.0], [84.0, 632.0, 120.0]],
+            model=net,
+            layer=net.model.linear0,
+            test_input=(inp1, inp2, inp3),
+            expected_ablation=[[316.0, 992.0, 316.0], [84.0, 632.0, 120.0]],
             additional_input=(1,),
             baselines=baseline,
             layer_mask=layer_mask,
@@ -91,28 +91,32 @@ def test_simple_multi_input_conv_intermediate(self) -> None:
         inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)
         inp2 = torch.ones((1, 1, 4, 4))
         self._ablation_test_assert(
-            net,
-            net.relu1,
-            (inp, inp2),
-            [[[[4.0, 13.0], [40.0, 49.0]], [[0, 0], [-15.0, -24.0]]]],
+            model=net,
+            layer=net.relu1,
+            test_input=(inp, inp2),
+            expected_ablation=[[[[4.0, 13.0], [40.0, 49.0]], [[0, 0], [-15.0, -24.0]]]],
             perturbations_per_eval=(1, 2, 4, 8, 12, 16),
         )
         self._ablation_test_assert(
-            net,
-            net.relu1,
-            (inp, inp2),
-            ([[[4.0, 13.0], [40.0, 49.0]], [[0, 0], [-15.0, -24.0]]],),
+            model=net,
+            layer=net.relu1,
+            test_input=(inp, inp2),
+            expected_ablation=(
+                [[[4.0, 13.0], [40.0, 49.0]], [[0, 0], [-15.0, -24.0]]],
+            ),
             baselines=torch.tensor(
                 [[[-4.0, -13.0], [-2.0, -2.0]], [[0, 0], [0.0, 0.0]]]
             ),
             perturbations_per_eval=(1, 2, 4, 8, 12, 16),
             attribute_to_layer_input=True,
         )
         self._ablation_test_assert(
-            net,
-            net.relu1,
-            (inp, inp2),
-            [[[[17.0, 17.0], [67.0, 67.0]], [[0, 0], [-39.0, -39.0]]]],
+            model=net,
+            layer=net.relu1,
+            test_input=(inp, inp2),
+            expected_ablation=[
+                [[[17.0, 17.0], [67.0, 67.0]], [[0, 0], [-39.0, -39.0]]]
+            ],
             perturbations_per_eval=(1, 2, 4),
             layer_mask=torch.tensor([[[[0, 0], [1, 1]], [[2, 2], [3, 3]]]]),
         )
@@ -121,17 +125,20 @@ def test_simple_multi_output_ablation(self) -> None:
         net = BasicModel_MultiLayer(multi_input_module=True)
         inp = torch.tensor([[0.0, 6.0, 0.0]])
         self._ablation_test_assert(
-            net, net.multi_relu, inp, ([[0.0, 7.0, 7.0, 7.0]], [[0.0, 7.0, 7.0, 7.0]])
+            model=net,
+            layer=net.multi_relu,
+            test_input=inp,
+            expected_ablation=([[0.0, 7.0, 7.0, 7.0]], [[0.0, 7.0, 7.0, 7.0]]),
         )

     def test_simple_multi_output_input_ablation(self) -> None:
         net = BasicModel_MultiLayer(multi_input_module=True)
         inp = torch.tensor([[0.0, 6.0, 0.0]])
         self._ablation_test_assert(
-            net,
-            net.multi_relu,
-            inp,
-            ([[0.0, 7.0, 7.0, 7.0]], [[0.0, 7.0, 7.0, 7.0]]),
+            model=net,
+            layer=net.multi_relu,
+            test_input=inp,
+            expected_ablation=([[0.0, 7.0, 7.0, 7.0]], [[0.0, 7.0, 7.0, 7.0]]),
             attribute_to_layer_input=True,
         )

@@ -151,7 +158,7 @@ def _ablation_test_assert(
         for batch_size in perturbations_per_eval:
             ablation = LayerFeatureAblation(model, layer)
             attributions = ablation.attribute(
-                test_input,
+                inputs=test_input,
                 target=target,
                 layer_mask=layer_mask,
                 additional_forward_args=additional_input,
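For reference, a minimal sketch (not part of this PR) of the keyword-argument call style the updated tests exercise, written against Captum's public LayerFeatureAblation API; BasicModel_MultiLayer is the helper model this suite already uses, and the mask and values below are illustrative assumptions.

# Minimal sketch, not part of this PR: keyword arguments at the attribute()
# call site, mirroring the style the updated tests use. Values are illustrative.
import torch
from captum.attr import LayerFeatureAblation
from tests.helpers.basic_models import BasicModel_MultiLayer

net = BasicModel_MultiLayer()
inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)

ablation = LayerFeatureAblation(net, net.linear0)
attributions = ablation.attribute(
    inputs=inp,                            # named, as in the updated tests
    layer_mask=torch.tensor([[0, 0, 1]]),  # ablate the first two units as one group
    perturbations_per_eval=2,              # batch two ablations per forward pass
    attribute_to_layer_input=True,         # attribute to the layer's input, not its output
)

Naming every argument makes calls with many optional parameters self-documenting and robust to signature reordering, which is the point of the refactor above.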
tests/attr/layer/test_layer_feature_permutation.py (33 changes: 33 additions & 0 deletions)
@@ -0,0 +1,33 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+import torch
+from captum.attr._core.layer.layer_feature_permutation import LayerFeaturePermutation
+from tests.helpers import BaseTest
+from tests.helpers.basic import assertTensorAlmostEqual
+from tests.helpers.basic_models import BasicModel_MultiLayer
+from torch import Tensor
+
+
+class TestLayerFeaturePermutation(BaseTest):
+    def test_single_input(self) -> None:
+        net = BasicModel_MultiLayer()
+        feature_importance = LayerFeaturePermutation(
+            forward_func=net,
+            layer=net.linear0,
+        )
+
+        batch_size = 2
+        input_size = (3,)
+        constant_value = 10000
+
+        inp = torch.randn((batch_size,) + input_size)
+        inp[:, 0] = constant_value
+
+        attribs = feature_importance.attribute(inputs=inp)
+
+        self.assertTrue(isinstance(attribs, Tensor))
+        self.assertEqual(len(attribs), 4)
+        self.assertEqual(attribs.squeeze(0).size(), (2 * batch_size,) + input_size)
+        zeros = torch.zeros(2 * batch_size)
+        assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0, mode="max")
+        self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all())
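The assertions in this new test hinge on a basic property of permutation importance: permuting a feature that is constant across the batch leaves the input unchanged, so its attribution must be exactly zero, while varying features generally get nonzero scores. A toy sketch of that idea follows (an illustration only, not Captum's implementation; permutation_importance is a hypothetical helper).

# Toy sketch, not Captum's implementation: permute one feature across the
# batch and measure the change in model output; constant features score zero.
import torch

def permutation_importance(forward_func, inp: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        base = forward_func(inp)
        scores = torch.zeros_like(inp)
        for j in range(inp.shape[1]):
            shuffled = inp.clone()
            perm = torch.randperm(inp.shape[0])
            shuffled[:, j] = inp[perm, j]  # permute feature j across the batch
            # Total absolute output change per sample attributed to feature j.
            scores[:, j] = (base - forward_func(shuffled)).abs().sum(dim=1)
    return scores

net = torch.nn.Linear(3, 2)
inp = torch.randn(8, 3)
inp[:, 0] = 10000.0  # constant column, as in the test above -> zero importance
print(permutation_importance(net, inp))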
self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all())