Skip to content

Commit d55c09f

Browse files
jjunchofacebook-github-bot
authored andcommitted
Future Async support for more test coverage (#1317)
Summary: Since Futures async is available for Feature Attributor, this diff expands the test cases that support/use FeatureAblation to cover all existing test cases in the diff. Reviewed By: cyrjano Differential Revision: D60118869
1 parent 430793e commit d55c09f

File tree

1 file changed

+187
-59
lines changed

1 file changed

+187
-59
lines changed

tests/attr/test_feature_permutation.py

Lines changed: 187 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import torch
55
from captum.attr._core.feature_permutation import _permute_feature, FeaturePermutation
6-
from parameterized import parameterized
76
from tests.helpers import BaseTest
87
from tests.helpers.basic import assertTensorAlmostEqual
98
from tests.helpers.basic_models import BasicModelWithSparseInputs
@@ -86,39 +85,59 @@ def test_perm_fn_broadcastable_masks(self) -> None:
8685

8786
self._check_perm_fn_with_mask(inp, mask)
8887

89-
@parameterized.expand([(True,), (False,)])
90-
def test_single_input(self, use_futures) -> None:
88+
def test_single_input(self) -> None:
9189
batch_size = 2
9290
input_size = (6,)
9391
constant_value = 10000
9492

9593
def forward_func(x: Tensor) -> Tensor:
9694
return x.sum(dim=-1)
9795

98-
if use_futures:
99-
feature_importance = FeaturePermutation(
100-
forward_func=self.construct_future_forward(forward_func)
101-
)
102-
feature_importance.use_futures = use_futures
96+
feature_importance = FeaturePermutation(forward_func=forward_func)
97+
98+
inp = torch.randn((batch_size,) + input_size)
99+
100+
inp[:, 0] = constant_value
101+
zeros = torch.zeros_like(inp[:, 0])
102+
103+
attribs = feature_importance.attribute(inp)
104+
105+
self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size)
106+
assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max")
107+
self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all())
103108

104-
else:
105-
feature_importance = FeaturePermutation(forward_func=forward_func)
109+
def test_single_input_with_future(
110+
self,
111+
) -> None:
112+
batch_size = 2
113+
input_size = (6,)
114+
constant_value = 10000
115+
116+
def forward_func(x: Tensor) -> Tensor:
117+
return x.sum(dim=-1)
118+
119+
feature_importance = FeaturePermutation(
120+
forward_func=self.construct_future_forward(forward_func)
121+
)
122+
feature_importance.use_futures = True
106123

107124
inp = torch.randn((batch_size,) + input_size)
108125

109126
inp[:, 0] = constant_value
110127
zeros = torch.zeros_like(inp[:, 0])
111-
if use_futures:
112-
attribs = feature_importance.attribute_future(inp).wait()
113-
else:
114-
attribs = feature_importance.attribute(inp)
128+
129+
attribs = feature_importance.attribute_future(inp)
130+
131+
self.assertTrue(type(attribs) is torch.Future)
132+
attribs = attribs.wait()
115133

116134
self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size)
117135
assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max")
118136
self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all())
119137

120-
@parameterized.expand([(True,), (False,)])
121-
def test_multi_input(self, use_futures) -> None:
138+
def test_multi_input(
139+
self,
140+
) -> None:
122141
batch_size = 20
123142
inp1_size = (5, 2)
124143
inp2_size = (5, 3)
@@ -133,14 +152,55 @@ def forward_func(*x: Tensor) -> Tensor:
133152

134153
return torch.mean((y - labels) ** 2)
135154

136-
if use_futures:
137-
feature_importance = FeaturePermutation(
138-
forward_func=self.construct_future_forward(forward_func)
139-
)
140-
feature_importance.use_futures = use_futures
155+
feature_importance = FeaturePermutation(forward_func=forward_func)
141156

142-
else:
143-
feature_importance = FeaturePermutation(forward_func=forward_func)
157+
inp = (
158+
torch.randn((batch_size,) + inp1_size),
159+
torch.randn((batch_size,) + inp2_size),
160+
)
161+
162+
feature_mask = (
163+
torch.arange(inp[0][0].numel()).view_as(inp[0][0]).unsqueeze(0),
164+
torch.arange(inp[1][0].numel()).view_as(inp[1][0]).unsqueeze(0),
165+
)
166+
167+
inp[1][:, :, 1] = 4
168+
169+
attribs = feature_importance.attribute(inp, feature_mask=feature_mask)
170+
171+
self.assertTrue(isinstance(attribs, tuple))
172+
self.assertTrue(len(attribs) == 2)
173+
174+
self.assertTrue(attribs[0].squeeze(0).size() == inp1_size)
175+
self.assertTrue(attribs[1].squeeze(0).size() == inp2_size)
176+
177+
self.assertTrue((attribs[1][:, :, 1] == 0).all())
178+
self.assertTrue((attribs[1][:, :, 2] == 0).all())
179+
180+
self.assertTrue((attribs[0] != 0).all())
181+
self.assertTrue((attribs[1][:, :, 0] != 0).all())
182+
183+
def test_multi_input_with_future(
184+
self,
185+
) -> None:
186+
batch_size = 20
187+
inp1_size = (5, 2)
188+
inp2_size = (5, 3)
189+
190+
labels = torch.randn(batch_size)
191+
192+
def forward_func(*x: Tensor) -> Tensor:
193+
y = torch.zeros(x[0].shape[0:2])
194+
for xx in x:
195+
y += xx[:, :, 0] * xx[:, :, 1]
196+
y = y.sum(dim=-1)
197+
198+
return torch.mean((y - labels) ** 2)
199+
200+
feature_importance = FeaturePermutation(
201+
forward_func=self.construct_future_forward(forward_func)
202+
)
203+
feature_importance.use_futures = True
144204

145205
inp = (
146206
torch.randn((batch_size,) + inp1_size),
@@ -154,12 +214,9 @@ def forward_func(*x: Tensor) -> Tensor:
154214

155215
inp[1][:, :, 1] = 4
156216

157-
if use_futures:
158-
attribs = feature_importance.attribute_future(
159-
inp, feature_mask=feature_mask
160-
).wait()
161-
else:
162-
attribs = feature_importance.attribute(inp, feature_mask=feature_mask)
217+
attribs = feature_importance.attribute_future(inp, feature_mask=feature_mask)
218+
self.assertTrue(type(attribs) is torch.Future)
219+
attribs = attribs.wait()
163220

164221
self.assertTrue(isinstance(attribs, tuple))
165222
self.assertTrue(len(attribs) == 2)
@@ -173,8 +230,9 @@ def forward_func(*x: Tensor) -> Tensor:
173230
self.assertTrue((attribs[0] != 0).all())
174231
self.assertTrue((attribs[1][:, :, 0] != 0).all())
175232

176-
@parameterized.expand([(True,), (False,)])
177-
def test_mulitple_perturbations_per_eval(self, use_futures) -> None:
233+
def test_multiple_perturbations_per_eval(
234+
self,
235+
) -> None:
178236
perturbations_per_eval = 4
179237
batch_size = 2
180238
input_size = (4,)
@@ -185,21 +243,51 @@ def forward_func(x: Tensor) -> Tensor:
185243
return 1 - x
186244

187245
target = 1
188-
if use_futures:
189-
feature_importance = FeaturePermutation(
190-
forward_func=self.construct_future_forward(forward_func)
191-
)
192-
feature_importance.use_futures = use_futures
193-
attribs = feature_importance.attribute_future(
194-
inp, perturbations_per_eval=perturbations_per_eval, target=target
195-
).wait()
196-
else:
197-
feature_importance = FeaturePermutation(forward_func=forward_func)
198-
199-
attribs = feature_importance.attribute(
200-
inp, perturbations_per_eval=perturbations_per_eval, target=target
246+
247+
feature_importance = FeaturePermutation(forward_func=forward_func)
248+
249+
attribs = feature_importance.attribute(
250+
inp, perturbations_per_eval=perturbations_per_eval, target=target
251+
)
252+
253+
self.assertTrue(attribs.size() == (batch_size,) + input_size)
254+
255+
for i in range(inp.size(1)):
256+
if i == target:
257+
continue
258+
assertTensorAlmostEqual(
259+
self, attribs[:, i], torch.zeros_like(attribs[:, i])
201260
)
202261

262+
y = forward_func(inp)
263+
actual_diff = torch.stack([(y[0] - y[1])[target], (y[1] - y[0])[target]])
264+
assertTensorAlmostEqual(self, attribs[:, target], actual_diff)
265+
266+
def test_multiple_perturbations_per_eval_with_futures(
267+
self,
268+
) -> None:
269+
perturbations_per_eval = 4
270+
batch_size = 2
271+
input_size = (4,)
272+
273+
inp = torch.randn((batch_size,) + input_size)
274+
275+
def forward_func(x: Tensor) -> Tensor:
276+
return 1 - x
277+
278+
target = 1
279+
280+
feature_importance = FeaturePermutation(
281+
forward_func=self.construct_future_forward(forward_func)
282+
)
283+
feature_importance.use_futures = True
284+
285+
attribs = feature_importance.attribute_future(
286+
inp, perturbations_per_eval=perturbations_per_eval, target=target
287+
)
288+
self.assertTrue(type(attribs) is torch.Future)
289+
attribs = attribs.wait()
290+
203291
self.assertTrue(attribs.size() == (batch_size,) + input_size)
204292

205293
for i in range(inp.size(1)):
@@ -213,22 +301,18 @@ def forward_func(x: Tensor) -> Tensor:
213301
actual_diff = torch.stack([(y[0] - y[1])[target], (y[1] - y[0])[target]])
214302
assertTensorAlmostEqual(self, attribs[:, target], actual_diff)
215303

216-
@parameterized.expand([(True,), (False,)])
217-
def test_broadcastable_masks(self, use_futures) -> None:
304+
def test_broadcastable_masks(
305+
self,
306+
) -> None:
218307
# integration test to ensure that
219308
# permutation function works with custom masks
220309
def forward_func(x: Tensor) -> Tensor:
221310
return x.view(x.shape[0], -1).sum(dim=-1)
222311

223312
batch_size = 2
224313
inp = torch.randn((batch_size,) + (3, 4, 4))
225-
if use_futures:
226-
feature_importance = FeaturePermutation(
227-
forward_func=self.construct_future_forward(forward_func)
228-
)
229-
feature_importance.use_futures = use_futures
230-
else:
231-
feature_importance = FeaturePermutation(forward_func=forward_func)
314+
315+
feature_importance = FeaturePermutation(forward_func=forward_func)
232316

233317
masks = [
234318
torch.tensor([0]),
@@ -237,12 +321,8 @@ def forward_func(x: Tensor) -> Tensor:
237321
]
238322

239323
for mask in masks:
240-
if use_futures:
241-
attribs = feature_importance.attribute_future(
242-
inp, feature_mask=mask
243-
).wait()
244-
else:
245-
attribs = feature_importance.attribute(inp, feature_mask=mask)
324+
325+
attribs = feature_importance.attribute(inp, feature_mask=mask)
246326
self.assertTrue(attribs is not None)
247327
self.assertTrue(attribs.shape == inp.shape)
248328

@@ -260,6 +340,54 @@ def forward_func(x: Tensor) -> Tensor:
260340
mode="max",
261341
)
262342

343+
def test_broadcastable_masks_with_future(
344+
self,
345+
) -> None:
346+
# integration test to ensure that
347+
# permutation function works with custom masks
348+
def forward_func(x: Tensor) -> Tensor:
349+
return x.view(x.shape[0], -1).sum(dim=-1)
350+
351+
batch_size = 2
352+
inp = torch.randn((batch_size,) + (3, 4, 4))
353+
354+
feature_importance = FeaturePermutation(
355+
forward_func=self.construct_future_forward(forward_func)
356+
)
357+
feature_importance.use_futures = True
358+
359+
masks = [
360+
torch.tensor([0]),
361+
torch.tensor([[0, 1, 2, 3]]),
362+
torch.tensor([[[0, 1, 2, 3], [3, 3, 4, 5], [6, 6, 4, 6], [7, 8, 9, 10]]]),
363+
]
364+
365+
results = []
366+
367+
for mask in masks:
368+
attribs_future = feature_importance.attribute_future(inp, feature_mask=mask)
369+
results.append(attribs_future)
370+
self.assertTrue(attribs_future is not None)
371+
372+
for idx in range(len(results)):
373+
attribs = results[idx].wait()
374+
self.assertTrue(attribs is not None)
375+
self.assertTrue(attribs.shape == inp.shape)
376+
377+
fm = masks[idx].expand_as(inp[0])
378+
379+
features = set(masks[idx].flatten())
380+
for feature in features:
381+
m = (fm == feature).bool()
382+
attribs_for_feature = attribs[:, m]
383+
assertTensorAlmostEqual(
384+
self,
385+
attribs_for_feature[0],
386+
-attribs_for_feature[1],
387+
delta=0.05,
388+
mode="max",
389+
)
390+
263391
def test_empty_sparse_features(self) -> None:
264392
model = BasicModelWithSparseInputs()
265393
inp1 = torch.tensor([[1.0, -2.0, 3.0], [2.0, -1.0, 3.0]])

0 commit comments

Comments
 (0)