33
44import torch
55from captum .attr ._core .feature_permutation import _permute_feature , FeaturePermutation
6- from parameterized import parameterized
76from tests .helpers import BaseTest
87from tests .helpers .basic import assertTensorAlmostEqual
98from tests .helpers .basic_models import BasicModelWithSparseInputs
@@ -86,39 +85,59 @@ def test_perm_fn_broadcastable_masks(self) -> None:
8685
8786 self ._check_perm_fn_with_mask (inp , mask )
8887
def test_single_input(self) -> None:
    """Single-tensor input: a constant column must receive ~zero attribution.

    Permuting a feature that is identical across the batch cannot change
    the forward output, so its permutation importance is (near) zero,
    while every varying feature gets a non-zero score.

    NOTE(review): reconstructed from a diff view of this file — confirm
    formatting against the committed source.
    """
    batch_size = 2
    input_size = (6,)
    constant_value = 10000

    def forward_func(x: Tensor) -> Tensor:
        # Sum over the feature dimension: output depends on every feature.
        return x.sum(dim=-1)

    feature_importance = FeaturePermutation(forward_func=forward_func)

    inp = torch.randn((batch_size,) + input_size)

    # Make column 0 constant across the batch; permuting it is a no-op.
    inp[:, 0] = constant_value
    zeros = torch.zeros_like(inp[:, 0])

    attribs = feature_importance.attribute(inp)

    self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size)
    assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max")
    self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all())
def test_single_input_with_future(self) -> None:
    """Futures variant of test_single_input.

    Same constant-column expectation, but the attribution is requested via
    ``attribute_future`` and must come back as a ``torch.Future`` that is
    explicitly waited on.

    NOTE(review): reconstructed from a diff view of this file — confirm
    formatting against the committed source.
    """
    batch_size = 2
    input_size = (6,)
    constant_value = 10000

    def forward_func(x: Tensor) -> Tensor:
        return x.sum(dim=-1)

    feature_importance = FeaturePermutation(
        forward_func=self.construct_future_forward(forward_func)
    )
    feature_importance.use_futures = True

    inp = torch.randn((batch_size,) + input_size)

    # Column 0 is constant across the batch, so its permutation importance
    # must be (near) zero.
    inp[:, 0] = constant_value
    zeros = torch.zeros_like(inp[:, 0])

    attribs = feature_importance.attribute_future(inp)

    # The futures code path must hand back an unresolved torch.Future.
    self.assertTrue(type(attribs) is torch.Future)
    attribs = attribs.wait()

    self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size)
    assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max")
    self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all())
120- @parameterized .expand ([(True ,), (False ,)])
121- def test_multi_input (self , use_futures ) -> None :
138+ def test_multi_input (
139+ self ,
140+ ) -> None :
122141 batch_size = 20
123142 inp1_size = (5 , 2 )
124143 inp2_size = (5 , 3 )
@@ -133,14 +152,55 @@ def forward_func(*x: Tensor) -> Tensor:
133152
134153 return torch .mean ((y - labels ) ** 2 )
135154
136- if use_futures :
137- feature_importance = FeaturePermutation (
138- forward_func = self .construct_future_forward (forward_func )
139- )
140- feature_importance .use_futures = use_futures
155+ feature_importance = FeaturePermutation (forward_func = forward_func )
141156
142- else :
143- feature_importance = FeaturePermutation (forward_func = forward_func )
157+ inp = (
158+ torch .randn ((batch_size ,) + inp1_size ),
159+ torch .randn ((batch_size ,) + inp2_size ),
160+ )
161+
162+ feature_mask = (
163+ torch .arange (inp [0 ][0 ].numel ()).view_as (inp [0 ][0 ]).unsqueeze (0 ),
164+ torch .arange (inp [1 ][0 ].numel ()).view_as (inp [1 ][0 ]).unsqueeze (0 ),
165+ )
166+
167+ inp [1 ][:, :, 1 ] = 4
168+
169+ attribs = feature_importance .attribute (inp , feature_mask = feature_mask )
170+
171+ self .assertTrue (isinstance (attribs , tuple ))
172+ self .assertTrue (len (attribs ) == 2 )
173+
174+ self .assertTrue (attribs [0 ].squeeze (0 ).size () == inp1_size )
175+ self .assertTrue (attribs [1 ].squeeze (0 ).size () == inp2_size )
176+
177+ self .assertTrue ((attribs [1 ][:, :, 1 ] == 0 ).all ())
178+ self .assertTrue ((attribs [1 ][:, :, 2 ] == 0 ).all ())
179+
180+ self .assertTrue ((attribs [0 ] != 0 ).all ())
181+ self .assertTrue ((attribs [1 ][:, :, 0 ] != 0 ).all ())
182+
183+ def test_multi_input_with_future (
184+ self ,
185+ ) -> None :
186+ batch_size = 20
187+ inp1_size = (5 , 2 )
188+ inp2_size = (5 , 3 )
189+
190+ labels = torch .randn (batch_size )
191+
192+ def forward_func (* x : Tensor ) -> Tensor :
193+ y = torch .zeros (x [0 ].shape [0 :2 ])
194+ for xx in x :
195+ y += xx [:, :, 0 ] * xx [:, :, 1 ]
196+ y = y .sum (dim = - 1 )
197+
198+ return torch .mean ((y - labels ) ** 2 )
199+
200+ feature_importance = FeaturePermutation (
201+ forward_func = self .construct_future_forward (forward_func )
202+ )
203+ feature_importance .use_futures = True
144204
145205 inp = (
146206 torch .randn ((batch_size ,) + inp1_size ),
@@ -154,12 +214,9 @@ def forward_func(*x: Tensor) -> Tensor:
154214
155215 inp [1 ][:, :, 1 ] = 4
156216
157- if use_futures :
158- attribs = feature_importance .attribute_future (
159- inp , feature_mask = feature_mask
160- ).wait ()
161- else :
162- attribs = feature_importance .attribute (inp , feature_mask = feature_mask )
217+ attribs = feature_importance .attribute_future (inp , feature_mask = feature_mask )
218+ self .assertTrue (type (attribs ) is torch .Future )
219+ attribs = attribs .wait ()
163220
164221 self .assertTrue (isinstance (attribs , tuple ))
165222 self .assertTrue (len (attribs ) == 2 )
@@ -173,8 +230,9 @@ def forward_func(*x: Tensor) -> Tensor:
173230 self .assertTrue ((attribs [0 ] != 0 ).all ())
174231 self .assertTrue ((attribs [1 ][:, :, 0 ] != 0 ).all ())
175232
176- @parameterized .expand ([(True ,), (False ,)])
177- def test_mulitple_perturbations_per_eval (self , use_futures ) -> None :
233+ def test_multiple_perturbations_per_eval (
234+ self ,
235+ ) -> None :
178236 perturbations_per_eval = 4
179237 batch_size = 2
180238 input_size = (4 ,)
@@ -185,21 +243,51 @@ def forward_func(x: Tensor) -> Tensor:
185243 return 1 - x
186244
187245 target = 1
188- if use_futures :
189- feature_importance = FeaturePermutation (
190- forward_func = self .construct_future_forward (forward_func )
191- )
192- feature_importance .use_futures = use_futures
193- attribs = feature_importance .attribute_future (
194- inp , perturbations_per_eval = perturbations_per_eval , target = target
195- ).wait ()
196- else :
197- feature_importance = FeaturePermutation (forward_func = forward_func )
198-
199- attribs = feature_importance .attribute (
200- inp , perturbations_per_eval = perturbations_per_eval , target = target
246+
247+ feature_importance = FeaturePermutation (forward_func = forward_func )
248+
249+ attribs = feature_importance .attribute (
250+ inp , perturbations_per_eval = perturbations_per_eval , target = target
251+ )
252+
253+ self .assertTrue (attribs .size () == (batch_size ,) + input_size )
254+
255+ for i in range (inp .size (1 )):
256+ if i == target :
257+ continue
258+ assertTensorAlmostEqual (
259+ self , attribs [:, i ], torch .zeros_like (attribs [:, i ])
201260 )
202261
262+ y = forward_func (inp )
263+ actual_diff = torch .stack ([(y [0 ] - y [1 ])[target ], (y [1 ] - y [0 ])[target ]])
264+ assertTensorAlmostEqual (self , attribs [:, target ], actual_diff )
265+
266+ def test_multiple_perturbations_per_eval_with_futures (
267+ self ,
268+ ) -> None :
269+ perturbations_per_eval = 4
270+ batch_size = 2
271+ input_size = (4 ,)
272+
273+ inp = torch .randn ((batch_size ,) + input_size )
274+
275+ def forward_func (x : Tensor ) -> Tensor :
276+ return 1 - x
277+
278+ target = 1
279+
280+ feature_importance = FeaturePermutation (
281+ forward_func = self .construct_future_forward (forward_func )
282+ )
283+ feature_importance .use_futures = True
284+
285+ attribs = feature_importance .attribute_future (
286+ inp , perturbations_per_eval = perturbations_per_eval , target = target
287+ )
288+ self .assertTrue (type (attribs ) is torch .Future )
289+ attribs = attribs .wait ()
290+
203291 self .assertTrue (attribs .size () == (batch_size ,) + input_size )
204292
205293 for i in range (inp .size (1 )):
@@ -213,22 +301,18 @@ def forward_func(x: Tensor) -> Tensor:
213301 actual_diff = torch .stack ([(y [0 ] - y [1 ])[target ], (y [1 ] - y [0 ])[target ]])
214302 assertTensorAlmostEqual (self , attribs [:, target ], actual_diff )
215303
216- @parameterized .expand ([(True ,), (False ,)])
217- def test_broadcastable_masks (self , use_futures ) -> None :
304+ def test_broadcastable_masks (
305+ self ,
306+ ) -> None :
218307 # integration test to ensure that
219308 # permutation function works with custom masks
220309 def forward_func (x : Tensor ) -> Tensor :
221310 return x .view (x .shape [0 ], - 1 ).sum (dim = - 1 )
222311
223312 batch_size = 2
224313 inp = torch .randn ((batch_size ,) + (3 , 4 , 4 ))
225- if use_futures :
226- feature_importance = FeaturePermutation (
227- forward_func = self .construct_future_forward (forward_func )
228- )
229- feature_importance .use_futures = use_futures
230- else :
231- feature_importance = FeaturePermutation (forward_func = forward_func )
314+
315+ feature_importance = FeaturePermutation (forward_func = forward_func )
232316
233317 masks = [
234318 torch .tensor ([0 ]),
@@ -237,12 +321,8 @@ def forward_func(x: Tensor) -> Tensor:
237321 ]
238322
239323 for mask in masks :
240- if use_futures :
241- attribs = feature_importance .attribute_future (
242- inp , feature_mask = mask
243- ).wait ()
244- else :
245- attribs = feature_importance .attribute (inp , feature_mask = mask )
324+
325+ attribs = feature_importance .attribute (inp , feature_mask = mask )
246326 self .assertTrue (attribs is not None )
247327 self .assertTrue (attribs .shape == inp .shape )
248328
@@ -260,6 +340,54 @@ def forward_func(x: Tensor) -> Tensor:
260340 mode = "max" ,
261341 )
262342
def test_broadcastable_masks_with_future(self) -> None:
    """Integration test: custom broadcastable masks with the futures API.

    For a sum-of-all-elements forward over a batch of two, permuting a
    feature group swaps that group's contribution between the two samples,
    so each group's attributions are equal and opposite across the batch.

    All futures are launched before any is waited on, so the async code
    path is exercised with several requests in flight.
    """

    def forward_func(x: Tensor) -> Tensor:
        return x.view(x.shape[0], -1).sum(dim=-1)

    batch_size = 2
    inp = torch.randn((batch_size,) + (3, 4, 4))

    feature_importance = FeaturePermutation(
        forward_func=self.construct_future_forward(forward_func)
    )
    feature_importance.use_futures = True

    # Masks of increasing rank; each broadcasts over a single example.
    masks = [
        torch.tensor([0]),
        torch.tensor([[0, 1, 2, 3]]),
        torch.tensor([[[0, 1, 2, 3], [3, 3, 4, 5], [6, 6, 4, 6], [7, 8, 9, 10]]]),
    ]

    # Kick off every attribution before waiting on any of them.
    results = []
    for mask in masks:
        attribs_future = feature_importance.attribute_future(inp, feature_mask=mask)
        results.append(attribs_future)
        self.assertTrue(attribs_future is not None)

    for mask, future in zip(masks, results):
        attribs = future.wait()
        self.assertTrue(attribs is not None)
        self.assertTrue(attribs.shape == inp.shape)

        fm = mask.expand_as(inp[0])

        # BUG FIX: collect unique feature ids as python ints. A set of
        # 0-dim tensors hashes by object identity, so duplicate ids (e.g.
        # the repeated 3s and 6s in the rank-3 mask) were never
        # deduplicated and each group was re-checked redundantly.
        features = set(mask.flatten().tolist())
        for feature in features:
            m = (fm == feature).bool()
            attribs_for_feature = attribs[:, m]
            assertTensorAlmostEqual(
                self,
                attribs_for_feature[0],
                -attribs_for_feature[1],
                delta=0.05,
                mode="max",
            )
390+
263391 def test_empty_sparse_features (self ) -> None :
264392 model = BasicModelWithSparseInputs ()
265393 inp1 = torch .tensor ([[1.0 , - 2.0 , 3.0 ], [2.0 , - 1.0 , 3.0 ]])
0 commit comments