Commit b1809a3

jjuncho authored and facebook-github-bot committed
Add Future async support to FeaturePermutation (#1316)
Summary:
Pull Request resolved: #1316

This diff adds support for PyTorch futures to FeaturePermutation and adjusts the unit tests to cover both the synchronous and futures-based paths.

Differential Revision: D60405087
1 parent 3543414 · commit b1809a3
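For context, the usage pattern the new unit tests exercise looks roughly like the sketch below: wrap a synchronous forward function so it returns a torch.futures.Future, opt in via the new use_futures flag, and resolve the Future returned by attribute_future with wait(). This is a minimal sketch of the call flow shown in this diff, not documented public API.

import torch
from captum.attr._core.feature_permutation import FeaturePermutation

def forward_func(x: torch.Tensor) -> torch.Tensor:
    return x.sum(dim=-1)

# Wrap the synchronous forward so it returns a torch.futures.Future,
# mirroring the construct_future_forward helper added in the tests below.
def future_forward(*args, **kwargs):
    fut = torch.futures.Future()
    fut.set_result(forward_func(*args, **kwargs))
    return fut

feature_importance = FeaturePermutation(forward_func=future_forward)
feature_importance.use_futures = True  # opt in to the futures code path

inp = torch.randn(2, 6)
# attribute_future returns a Future; wait() blocks until it resolves.
attribs = feature_importance.attribute_future(inp).wait()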

File tree

2 files changed: +101 −19 lines

captum/attr/_core/feature_permutation.py
tests/attr/test_feature_permutation.py

captum/attr/_core/feature_permutation.py

Lines changed: 27 additions & 0 deletions
@@ -6,6 +6,7 @@
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.log import log_usage
 from torch import Tensor
+from torch.futures import Future


 def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:

@@ -86,6 +87,7 @@ def __init__(
         """
         FeatureAblation.__init__(self, forward_func=forward_func)
         self.perm_func = perm_func
+        self.use_futures = False

     # suppressing error caused by the child class not having a matching
     # signature to the parent

@@ -271,6 +273,31 @@ def attribute( # type: ignore
             **kwargs,
         )

+    def attribute_future(
+        self,
+        inputs: TensorOrTupleOfTensorsGeneric,
+        target: TargetType = None,
+        additional_forward_args: Any = None,
+        feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
+        perturbations_per_eval: int = 1,
+        show_progress: bool = False,
+        **kwargs: Any,
+    ) -> Future[TensorOrTupleOfTensorsGeneric]:
+        if isinstance(kwargs, dict) and "baselines" in kwargs:
+            del kwargs["baselines"]
+        return FeatureAblation.attribute.__wrapped__(
+            self,
+            inputs,
+            baselines=None,
+            target=target,
+            additional_forward_args=additional_forward_args,
+            feature_mask=feature_mask,
+            perturbations_per_eval=perturbations_per_eval,
+            show_progress=show_progress,
+            use_futures=self.use_futures,
+            **kwargs,
+        )
+
     def _construct_ablated_input(
         self,
         expanded_input: Tensor,
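A note on the FeatureAblation.attribute.__wrapped__ call in attribute_future above: attribute is decorated with log_usage (imported at the top of the module), and decorators built on functools.wraps expose the original undecorated function as __wrapped__, so calling through it skips the decorator layer. The snippet below illustrates that mechanism with a hypothetical stand-in decorator; it is not captum's actual log_usage implementation.

import functools

def log_usage_stub(func):  # hypothetical stand-in for a logging decorator
    @functools.wraps(func)  # wraps() also stores func as wrapper.__wrapped__
    def wrapper(*args, **kwargs):
        # ... usage logging would happen here ...
        return func(*args, **kwargs)
    return wrapper

@log_usage_stub
def attribute(x):
    return x + 1

# Calling through __wrapped__ bypasses the decorator entirely.
assert attribute.__wrapped__(1) == 2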

tests/attr/test_feature_permutation.py

Lines changed: 74 additions & 19 deletions
@@ -1,15 +1,25 @@
 #!/usr/bin/env python3
-from typing import List, Tuple
+from typing import Callable, List, Tuple

 import torch
 from captum.attr._core.feature_permutation import _permute_feature, FeaturePermutation
+from parameterized import parameterized
 from tests.helpers import BaseTest
 from tests.helpers.basic import assertTensorAlmostEqual
 from tests.helpers.basic_models import BasicModelWithSparseInputs
 from torch import Tensor


+# pyre-ignore Undefined attribute [13]
 class Test(BaseTest):
+    def construct_future_forward(self, original_forward: Callable) -> Callable:
+        def future_forward(*args, **kwargs):
+            fut = torch.futures.Future()
+            fut.set_result(original_forward(*args, **kwargs))
+            return fut
+
+        return future_forward
+
     def _check_features_are_permuted(
         self, inp: Tensor, perm_inp: Tensor, mask: Tensor
     ) -> None:

@@ -76,28 +86,39 @@ def test_perm_fn_broadcastable_masks(self) -> None:

         self._check_perm_fn_with_mask(inp, mask)

-    def test_single_input(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_single_input(self, use_futures) -> None:
         batch_size = 2
         input_size = (6,)
         constant_value = 10000

         def forward_func(x: Tensor) -> Tensor:
             return x.sum(dim=-1)

-        feature_importance = FeaturePermutation(forward_func=forward_func)
+        if use_futures:
+            feature_importance = FeaturePermutation(
+                forward_func=self.construct_future_forward(forward_func)
+            )
+            feature_importance.use_futures = use_futures
+
+        else:
+            feature_importance = FeaturePermutation(forward_func=forward_func)

         inp = torch.randn((batch_size,) + input_size)

         inp[:, 0] = constant_value
         zeros = torch.zeros_like(inp[:, 0])
-
-        attribs = feature_importance.attribute(inp)
+        if use_futures:
+            attribs = feature_importance.attribute_future(inp).wait()
+        else:
+            attribs = feature_importance.attribute(inp)

         self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size)
         assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max")
         self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all())

-    def test_multi_input(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_multi_input(self, use_futures) -> None:
         batch_size = 20
         inp1_size = (5, 2)
         inp2_size = (5, 3)

@@ -112,7 +133,14 @@ def forward_func(*x: Tensor) -> Tensor:

             return torch.mean((y - labels) ** 2)

-        feature_importance = FeaturePermutation(forward_func=forward_func)
+        if use_futures:
+            feature_importance = FeaturePermutation(
+                forward_func=self.construct_future_forward(forward_func)
+            )
+            feature_importance.use_futures = use_futures
+
+        else:
+            feature_importance = FeaturePermutation(forward_func=forward_func)

         inp = (
             torch.randn((batch_size,) + inp1_size),

@@ -125,7 +153,13 @@ def forward_func(*x: Tensor) -> Tensor:
         )

         inp[1][:, :, 1] = 4
-        attribs = feature_importance.attribute(inp, feature_mask=feature_mask)
+
+        if use_futures:
+            attribs = feature_importance.attribute_future(
+                inp, feature_mask=feature_mask
+            ).wait()
+        else:
+            attribs = feature_importance.attribute(inp, feature_mask=feature_mask)

         self.assertTrue(isinstance(attribs, tuple))
         self.assertTrue(len(attribs) == 2)

@@ -139,22 +173,33 @@ def forward_func(*x: Tensor) -> Tensor:
         self.assertTrue((attribs[0] != 0).all())
         self.assertTrue((attribs[1][:, :, 0] != 0).all())

-    def test_mulitple_perturbations_per_eval(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_mulitple_perturbations_per_eval(self, use_futures) -> None:
         perturbations_per_eval = 4
         batch_size = 2
         input_size = (4,)

         inp = torch.randn((batch_size,) + input_size)

-        def forward_func(x):
+        def forward_func(x: Tensor) -> Tensor:
             return 1 - x

         target = 1
-        feature_importance = FeaturePermutation(forward_func=forward_func)
+        if use_futures:
+            feature_importance = FeaturePermutation(
+                forward_func=self.construct_future_forward(forward_func)
+            )
+            feature_importance.use_futures = use_futures
+            attribs = feature_importance.attribute_future(
+                inp, perturbations_per_eval=perturbations_per_eval, target=target
+            ).wait()
+        else:
+            feature_importance = FeaturePermutation(forward_func=forward_func)
+
+            attribs = feature_importance.attribute(
+                inp, perturbations_per_eval=perturbations_per_eval, target=target
+            )

-        attribs = feature_importance.attribute(
-            inp, perturbations_per_eval=perturbations_per_eval, target=target
-        )
         self.assertTrue(attribs.size() == (batch_size,) + input_size)

         for i in range(inp.size(1)):

@@ -168,16 +213,22 @@ def forward_func(x):
         actual_diff = torch.stack([(y[0] - y[1])[target], (y[1] - y[0])[target]])
         assertTensorAlmostEqual(self, attribs[:, target], actual_diff)

-    def test_broadcastable_masks(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_broadcastable_masks(self, use_futures) -> None:
         # integration test to ensure that
         # permutation function works with custom masks
         def forward_func(x: Tensor) -> Tensor:
             return x.view(x.shape[0], -1).sum(dim=-1)

         batch_size = 2
         inp = torch.randn((batch_size,) + (3, 4, 4))
-
-        feature_importance = FeaturePermutation(forward_func=forward_func)
+        if use_futures:
+            feature_importance = FeaturePermutation(
+                forward_func=self.construct_future_forward(forward_func)
+            )
+            feature_importance.use_futures = use_futures
+        else:
+            feature_importance = FeaturePermutation(forward_func=forward_func)

         masks = [
             torch.tensor([0]),

@@ -186,8 +237,12 @@ def forward_func(x: Tensor) -> Tensor:
         ]

         for mask in masks:
-            attribs = feature_importance.attribute(inp, feature_mask=mask)
-
+            if use_futures:
+                attribs = feature_importance.attribute_future(
+                    inp, feature_mask=mask
+                ).wait()
+            else:
+                attribs = feature_importance.attribute(inp, feature_mask=mask)
             self.assertTrue(attribs is not None)
             self.assertTrue(attribs.shape == inp.shape)
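For readers unfamiliar with the parameterized package used in these tests: @parameterized.expand generates one test case per argument tuple from a single test method, which is how each test above runs once with use_futures=True and once with use_futures=False. A minimal self-contained illustration (the class and method names here are invented for the example):

import unittest
from parameterized import parameterized

class ExampleTest(unittest.TestCase):
    # One method expands into two generated test cases,
    # one per tuple: use_futures=True and use_futures=False.
    @parameterized.expand([(True,), (False,)])
    def test_flag(self, use_futures: bool) -> None:
        self.assertIn(use_futures, (True, False))

if __name__ == "__main__":
    unittest.main()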
