 #!/usr/bin/env python3
-from typing import List, Tuple
+from typing import Callable, List, Tuple

 import torch
 from captum.attr._core.feature_permutation import _permute_feature, FeaturePermutation
+from parameterized import parameterized
 from tests.helpers import BaseTest
 from tests.helpers.basic import assertTensorAlmostEqual
 from tests.helpers.basic_models import BasicModelWithSparseInputs
 from torch import Tensor


+# pyre-ignore Undefined attribute [13]
 class Test(BaseTest):
+    def construct_future_forward(self, original_forward: Callable) -> Callable:
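+        # Wraps a synchronous forward function so it returns a torch.futures.Future,
+        # letting these tests exercise the future-based attribute_future path
+        # alongside the regular attribute call.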
+        def future_forward(*args, **kwargs):
+            fut = torch.futures.Future()
+            fut.set_result(original_forward(*args, **kwargs))
+            return fut
+
+        return future_forward
+
     def _check_features_are_permuted(
         self, inp: Tensor, perm_inp: Tensor, mask: Tensor
     ) -> None:
@@ -76,28 +86,39 @@ def test_perm_fn_broadcastable_masks(self) -> None:

         self._check_perm_fn_with_mask(inp, mask)

-    def test_single_input(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_single_input(self, use_futures) -> None:
         batch_size = 2
         input_size = (6,)
         constant_value = 10000

         def forward_func(x: Tensor) -> Tensor:
             return x.sum(dim=-1)

-        feature_importance = FeaturePermutation(forward_func=forward_func)
+        if use_futures:
+            feature_importance = FeaturePermutation(
+                forward_func=self.construct_future_forward(forward_func)
+            )
+            feature_importance.use_futures = use_futures
+
+        else:
+            feature_importance = FeaturePermutation(forward_func=forward_func)

         inp = torch.randn((batch_size,) + input_size)

         inp[:, 0] = constant_value
         zeros = torch.zeros_like(inp[:, 0])
-
-        attribs = feature_importance.attribute(inp)
+        if use_futures:
+            attribs = feature_importance.attribute_future(inp).wait()
+        else:
+            attribs = feature_importance.attribute(inp)

         self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size)
         assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max")
         self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all())

-    def test_multi_input(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_multi_input(self, use_futures) -> None:
         batch_size = 20
         inp1_size = (5, 2)
         inp2_size = (5, 3)
@@ -112,7 +133,14 @@ def forward_func(*x: Tensor) -> Tensor:

             return torch.mean((y - labels) ** 2)

-        feature_importance = FeaturePermutation(forward_func=forward_func)
+        if use_futures:
+            feature_importance = FeaturePermutation(
+                forward_func=self.construct_future_forward(forward_func)
+            )
+            feature_importance.use_futures = use_futures
+
+        else:
+            feature_importance = FeaturePermutation(forward_func=forward_func)

         inp = (
             torch.randn((batch_size,) + inp1_size),
@@ -125,7 +153,13 @@ def forward_func(*x: Tensor) -> Tensor:
         )

         inp[1][:, :, 1] = 4
-        attribs = feature_importance.attribute(inp, feature_mask=feature_mask)
+
+        if use_futures:
+            attribs = feature_importance.attribute_future(
+                inp, feature_mask=feature_mask
+            ).wait()
+        else:
+            attribs = feature_importance.attribute(inp, feature_mask=feature_mask)

         self.assertTrue(isinstance(attribs, tuple))
         self.assertTrue(len(attribs) == 2)
@@ -139,22 +173,33 @@ def forward_func(*x: Tensor) -> Tensor:
         self.assertTrue((attribs[0] != 0).all())
         self.assertTrue((attribs[1][:, :, 0] != 0).all())

-    def test_mulitple_perturbations_per_eval(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_mulitple_perturbations_per_eval(self, use_futures) -> None:
         perturbations_per_eval = 4
         batch_size = 2
         input_size = (4,)

         inp = torch.randn((batch_size,) + input_size)

-        def forward_func(x):
+        def forward_func(x: Tensor) -> Tensor:
             return 1 - x

         target = 1
-        feature_importance = FeaturePermutation(forward_func=forward_func)
+        if use_futures:
+            feature_importance = FeaturePermutation(
+                forward_func=self.construct_future_forward(forward_func)
+            )
+            feature_importance.use_futures = use_futures
+            attribs = feature_importance.attribute_future(
+                inp, perturbations_per_eval=perturbations_per_eval, target=target
+            ).wait()
+        else:
+            feature_importance = FeaturePermutation(forward_func=forward_func)
+
+            attribs = feature_importance.attribute(
+                inp, perturbations_per_eval=perturbations_per_eval, target=target
+            )

-        attribs = feature_importance.attribute(
-            inp, perturbations_per_eval=perturbations_per_eval, target=target
-        )
         self.assertTrue(attribs.size() == (batch_size,) + input_size)

         for i in range(inp.size(1)):
@@ -168,16 +213,22 @@ def forward_func(x):
         actual_diff = torch.stack([(y[0] - y[1])[target], (y[1] - y[0])[target]])
         assertTensorAlmostEqual(self, attribs[:, target], actual_diff)

-    def test_broadcastable_masks(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_broadcastable_masks(self, use_futures) -> None:
         # integration test to ensure that
         # permutation function works with custom masks
         def forward_func(x: Tensor) -> Tensor:
             return x.view(x.shape[0], -1).sum(dim=-1)

         batch_size = 2
         inp = torch.randn((batch_size,) + (3, 4, 4))
-
-        feature_importance = FeaturePermutation(forward_func=forward_func)
+        if use_futures:
+            feature_importance = FeaturePermutation(
+                forward_func=self.construct_future_forward(forward_func)
+            )
+            feature_importance.use_futures = use_futures
+        else:
+            feature_importance = FeaturePermutation(forward_func=forward_func)

         masks = [
             torch.tensor([0]),
@@ -186,8 +237,12 @@ def forward_func(x: Tensor) -> Tensor:
         ]

         for mask in masks:
-            attribs = feature_importance.attribute(inp, feature_mask=mask)
-
+            if use_futures:
+                attribs = feature_importance.attribute_future(
+                    inp, feature_mask=mask
+                ).wait()
+            else:
+                attribs = feature_importance.attribute(inp, feature_mask=mask)
             self.assertTrue(attribs is not None)
             self.assertTrue(attribs.shape == inp.shape)
