
Commit 3543414

yucu authored and facebook-github-bot committed
Initial version of async attribution with torch.futures (#1295)
Summary:
Pull Request resolved: #1295

Currently Captum doesn't support async forward functions. The Ads R&P team would like this feature in order to replace their custom variant (D56655643) of Feature Ablation with Captum while maintaining similar performance. PyTorch introduced the concept of futures ([link](https://pytorch.org/docs/stable/futures.html)), so we can adopt it in feature_ablation.py as a first step.

Details:
- The initial evaluation returns a future; save it.
- Each evaluation, for each feature of each input, returns an attribution result (plus a corresponding weight where applicable). Save all of those results separately, since futures cannot be added up directly.
- Once all of the futures above are done, add the evaluation results up into the final outcome, one Tensor per input (see the sketch after the commit metadata below).
- Since common._run_forward is used by other attribution methods, it needs some type hacking. Users who attempt to call those methods asynchronously will fail before reaching that code until Captum supports async for them.

TODO: Extend FeatureAttributor to support `torch.futures`

Reviewed By: vivekmig

Differential Revision: D56764316

fbshipit-source-id: 33661a76380dc009f4c9c60323b3c584e5152cda
1 parent 3125f59 commit 3543414
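A minimal sketch of the accumulation pattern described in the Details above, assuming a hypothetical async forward function. `async_ablation_sketch`, `toy_async_forward`, and the zeroing "ablation" are illustrative stand-ins, not the actual feature_ablation.py code; only `torch.futures.Future`, `collect_all`, `then`, and `wait` are real PyTorch APIs.

```python
from typing import Callable, List

import torch
from torch.futures import Future


def async_ablation_sketch(
    forward_func: Callable[[torch.Tensor], Future],
    inputs: torch.Tensor,
) -> Future:
    # Step 1: the initial evaluation returns a future; save it.
    initial_fut: Future = forward_func(inputs)

    # Step 2: each per-feature evaluation also returns a future. Futures
    # cannot be added up directly, so save every result separately.
    per_feature_futs: List[Future] = []
    for feature_idx in range(inputs.shape[1]):
        ablated = inputs.clone()
        ablated[:, feature_idx] = 0  # placeholder ablation of one feature
        per_feature_futs.append(forward_func(ablated))

    # Step 3: once all futures are done, add the evaluation results up into
    # the final outcome, one attribution Tensor per input.
    def combine(collected: Future) -> torch.Tensor:
        initial, *ablated_outs = [f.value() for f in collected.value()]
        return torch.stack([initial - out for out in ablated_outs], dim=1)

    return torch.futures.collect_all([initial_fut, *per_feature_futs]).then(combine)


def toy_async_forward(x: torch.Tensor) -> Future:
    # Toy stand-in: fulfills its future immediately with a row sum. A real
    # async forward would complete it from another thread or RPC callback.
    fut: Future = Future()
    fut.set_result(x.sum(dim=1))
    return fut


print(async_ablation_sketch(toy_async_forward, torch.ones(2, 3)).wait())
# tensor([[1., 1., 1.],
#         [1., 1., 1.]])
```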

File tree: 11 files changed (+401 -115 lines)


captum/_utils/common.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -15,6 +15,8 @@
     TupleOrTensorOrBoolGeneric,
 )
 from torch import device, Tensor
+
+from torch.futures import Future
 from torch.nn import Module

@@ -514,7 +516,7 @@ def _run_forward(
     inputs: Any,
     target: TargetType = None,
     additional_forward_args: Any = None,
-) -> Tensor:
+) -> Union[Tensor, Future[Tensor]]:
     forward_func_args = signature(forward_func).parameters
     if len(forward_func_args) == 0:
         output = forward_func()

@@ -532,6 +534,8 @@ def _run_forward(
             else inputs
         )
     )
+    if isinstance(output, torch.futures.Future):
+        return output.then(lambda x: _select_targets(x.value(), target))
     return _select_targets(output, target)
```
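With this change, a forward function that returns a `torch.futures.Future` flows through `_run_forward` as a future of the target-selected output, via the `.then()` chaining shown above. A minimal usage sketch (hedged: `_run_forward` is a private helper, and `async_forward` is an illustrative stand-in):

```python
import torch
from torch.futures import Future

from captum._utils.common import _run_forward


def async_forward(x: torch.Tensor) -> Future:
    # Illustrative async forward: fulfills its future immediately; a real
    # one would complete it from another thread or an RPC callback.
    fut: Future = Future()
    fut.set_result(x.sum(dim=1))
    return fut


out = _run_forward(async_forward, torch.ones(2, 3), target=None)
# The result is itself a Future; target selection has been chained onto it.
assert isinstance(out, torch.futures.Future)
print(out.wait())  # tensor([3., 3.])
```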

captum/_utils/gradient.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -112,6 +112,10 @@ def compute_gradients(
     with torch.autograd.set_grad_enabled(True):
         # runs forward pass
         outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
+        # _run_forward may return a future of Tensor,
+        # but we don't support that here yet;
+        # an async call will fail before reaching this point.
+        outputs = cast(Tensor, outputs)
         assert outputs[0].numel() == 1, (
             "Target not provided when necessary, cannot"
             " take gradient with respect to multiple outputs."

@@ -297,6 +301,10 @@ def forward_hook(module, inp, out=None):
             target=target_ind,
             additional_forward_args=additional_forward_args,
         )
+        # _run_forward may return a future of Tensor,
+        # but we don't support that here yet;
+        # an async call will fail before reaching this point.
+        output = cast(Tensor, output)
     finally:
         for hook in all_hooks:
             hook.remove()
```
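The `cast(Tensor, ...)` lines added above perform no runtime conversion: `typing.cast` is a no-op whose only effect is to tell static type checkers to treat `_run_forward`'s now-`Union` return type as a plain `Tensor`. A small illustration:

```python
from typing import cast

import torch
from torch.futures import Future

fut: Future = Future()
# cast() neither converts nor validates at runtime; it only narrows the
# type for static checkers. A real Future passed through here would still
# fail later when used as a Tensor.
maybe_tensor = cast(torch.Tensor, fut)
assert maybe_tensor is fut  # still the very same Future object
```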
