diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py
index ba23ad4ec6..a3e7580780 100644
--- a/captum/attr/_core/feature_permutation.py
+++ b/captum/attr/_core/feature_permutation.py
@@ -6,6 +6,7 @@
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.log import log_usage
 from torch import Tensor
+from torch.futures import Future
 
 
 def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
@@ -86,6 +87,7 @@ def __init__(
         """
         FeatureAblation.__init__(self, forward_func=forward_func)
         self.perm_func = perm_func
+        self.use_futures = False
 
     # suppressing error caused by the child class not having a matching
     # signature to the parent
@@ -271,6 +273,31 @@ def attribute(  # type: ignore
             **kwargs,
         )
 
+    def attribute_future(
+        self,
+        inputs: TensorOrTupleOfTensorsGeneric,
+        target: TargetType = None,
+        additional_forward_args: Any = None,
+        feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
+        perturbations_per_eval: int = 1,
+        show_progress: bool = False,
+        **kwargs: Any,
+    ) -> Future[TensorOrTupleOfTensorsGeneric]:
+        if isinstance(kwargs, dict) and "baselines" in kwargs:
+            del kwargs["baselines"]
+        return FeatureAblation.attribute.__wrapped__(
+            self,
+            inputs,
+            baselines=None,
+            target=target,
+            additional_forward_args=additional_forward_args,
+            feature_mask=feature_mask,
+            perturbations_per_eval=perturbations_per_eval,
+            show_progress=show_progress,
+            use_futures=self.use_futures,
+            **kwargs,
+        )
+
     def _construct_ablated_input(
         self,
         expanded_input: Tensor,
diff --git a/tests/attr/test_feature_permutation.py b/tests/attr/test_feature_permutation.py
index e432cd1a2c..c28ded3850 100644
--- a/tests/attr/test_feature_permutation.py
+++ b/tests/attr/test_feature_permutation.py
@@ -1,15 +1,25 @@
 #!/usr/bin/env python3
 
-from typing import List, Tuple
+from typing import Callable, List, Tuple
 
 import torch
 from captum.attr._core.feature_permutation import _permute_feature, FeaturePermutation
+from parameterized import parameterized
 from tests.helpers import BaseTest
 from tests.helpers.basic import assertTensorAlmostEqual
 from tests.helpers.basic_models import BasicModelWithSparseInputs
 from torch import Tensor
 
+# pyre-ignore Undefined attribute [13]
 class Test(BaseTest):
+    def construct_future_forward(self, original_forward: Callable) -> Callable:
+        def future_forward(*args, **kwargs):
+            fut = torch.futures.Future()
+            fut.set_result(original_forward(*args, **kwargs))
+            return fut
+
+        return future_forward
+
     def _check_features_are_permuted(
         self, inp: Tensor, perm_inp: Tensor, mask: Tensor
     ) -> None:
@@ -76,7 +86,8 @@ def test_perm_fn_broadcastable_masks(self) -> None:
 
             self._check_perm_fn_with_mask(inp, mask)
 
-    def test_single_input(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_single_input(self, use_futures) -> None:
         batch_size = 2
         input_size = (6,)
         constant_value = 10000
@@ -84,20 +95,30 @@ def forward_func(x: Tensor) -> Tensor:
 
         def forward_func(x: Tensor) -> Tensor:
             return x.sum(dim=-1)
 
-        feature_importance = FeaturePermutation(forward_func=forward_func)
+        if use_futures:
+            feature_importance = FeaturePermutation(
+                forward_func=self.construct_future_forward(forward_func)
+            )
+            feature_importance.use_futures = use_futures
+
+        else:
+            feature_importance = FeaturePermutation(forward_func=forward_func)
 
         inp = torch.randn((batch_size,) + input_size)
         inp[:, 0] = constant_value
         zeros = torch.zeros_like(inp[:, 0])
-
-        attribs = feature_importance.attribute(inp)
+        if use_futures:
+            attribs = feature_importance.attribute_future(inp).wait()
+        else:
+            attribs = feature_importance.attribute(inp)
 
         self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size)
         assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max")
         self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all())
 
-    def test_multi_input(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_multi_input(self, use_futures) -> None:
         batch_size = 20
         inp1_size = (5, 2)
         inp2_size = (5, 3)
@@ -112,7 +133,14 @@ def forward_func(*x: Tensor) -> Tensor:
 
             return torch.mean((y - labels) ** 2)
 
-        feature_importance = FeaturePermutation(forward_func=forward_func)
+        if use_futures:
+            feature_importance = FeaturePermutation(
+                forward_func=self.construct_future_forward(forward_func)
+            )
+            feature_importance.use_futures = use_futures
+
+        else:
+            feature_importance = FeaturePermutation(forward_func=forward_func)
 
         inp = (
             torch.randn((batch_size,) + inp1_size),
@@ -125,7 +153,13 @@ def forward_func(*x: Tensor) -> Tensor:
         )
         inp[1][:, :, 1] = 4
 
-        attribs = feature_importance.attribute(inp, feature_mask=feature_mask)
+
+        if use_futures:
+            attribs = feature_importance.attribute_future(
+                inp, feature_mask=feature_mask
+            ).wait()
+        else:
+            attribs = feature_importance.attribute(inp, feature_mask=feature_mask)
 
         self.assertTrue(isinstance(attribs, tuple))
         self.assertTrue(len(attribs) == 2)
@@ -139,22 +173,33 @@ def forward_func(*x: Tensor) -> Tensor:
         self.assertTrue((attribs[0] != 0).all())
         self.assertTrue((attribs[1][:, :, 0] != 0).all())
 
-    def test_mulitple_perturbations_per_eval(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_mulitple_perturbations_per_eval(self, use_futures) -> None:
         perturbations_per_eval = 4
         batch_size = 2
         input_size = (4,)
 
         inp = torch.randn((batch_size,) + input_size)
 
-        def forward_func(x):
+        def forward_func(x: Tensor) -> Tensor:
             return 1 - x
 
         target = 1
-        feature_importance = FeaturePermutation(forward_func=forward_func)
+        if use_futures:
+            feature_importance = FeaturePermutation(
+                forward_func=self.construct_future_forward(forward_func)
+            )
+            feature_importance.use_futures = use_futures
+            attribs = feature_importance.attribute_future(
+                inp, perturbations_per_eval=perturbations_per_eval, target=target
+            ).wait()
+        else:
+            feature_importance = FeaturePermutation(forward_func=forward_func)
+
+            attribs = feature_importance.attribute(
+                inp, perturbations_per_eval=perturbations_per_eval, target=target
+            )
 
-        attribs = feature_importance.attribute(
-            inp, perturbations_per_eval=perturbations_per_eval, target=target
-        )
         self.assertTrue(attribs.size() == (batch_size,) + input_size)
 
         for i in range(inp.size(1)):
@@ -168,7 +213,8 @@ def forward_func(x):
             actual_diff = torch.stack([(y[0] - y[1])[target], (y[1] - y[0])[target]])
             assertTensorAlmostEqual(self, attribs[:, target], actual_diff)
 
-    def test_broadcastable_masks(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_broadcastable_masks(self, use_futures) -> None:
         # integration test to ensure that
         # permutation function works with custom masks
         def forward_func(x: Tensor) -> Tensor:
@@ -176,8 +222,13 @@ def forward_func(x: Tensor) -> Tensor:
 
         batch_size = 2
         inp = torch.randn((batch_size,) + (3, 4, 4))
-
-        feature_importance = FeaturePermutation(forward_func=forward_func)
+        if use_futures:
+            feature_importance = FeaturePermutation(
+                forward_func=self.construct_future_forward(forward_func)
+            )
+            feature_importance.use_futures = use_futures
+        else:
+            feature_importance = FeaturePermutation(forward_func=forward_func)
 
         masks = [
             torch.tensor([0]),
@@ -186,8 +237,12 @@ def forward_func(x: Tensor) -> Tensor:
         ]
 
         for mask in masks:
-            attribs = feature_importance.attribute(inp, feature_mask=mask)
-
+            if use_futures:
+                attribs = feature_importance.attribute_future(
+                    inp, feature_mask=mask
+                ).wait()
+            else:
+                attribs = feature_importance.attribute(inp, feature_mask=mask)
             self.assertTrue(attribs is not None)
             self.assertTrue(attribs.shape == inp.shape)
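For reviewers, here is a minimal sketch (not part of the patch) of how the new `attribute_future` path is exercised end to end. It mirrors the `construct_future_forward` helper from the tests above: the forward function is wrapped so it returns a `torch.futures.Future`, `use_futures` is set on the instance, and the caller blocks on `.wait()`. The wrapper name and inputs are illustrative only.

```python
import torch
from captum.attr._core.feature_permutation import FeaturePermutation


def forward_func(x: torch.Tensor) -> torch.Tensor:
    return x.sum(dim=-1)


# Wrap a synchronous forward so it returns a torch.futures.Future,
# mirroring construct_future_forward in the tests above. In a real
# setup the forward itself would already be asynchronous.
def future_forward(*args, **kwargs) -> torch.futures.Future:
    fut: torch.futures.Future = torch.futures.Future()
    fut.set_result(forward_func(*args, **kwargs))
    return fut


feature_importance = FeaturePermutation(forward_func=future_forward)
feature_importance.use_futures = True  # opt in to the Future-based path

inp = torch.randn(2, 6)
# attribute_future returns a Future; block on .wait() for the attributions.
attribs = feature_importance.attribute_future(inp).wait()
```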