Skip to content

Commit 3188421

Browse files
sarahtranfbfacebook-github-bot
authored andcommitted
Add enable_cross_tensor_attribution tests to test_config (#1511)
Summary: Pull Request resolved: #1511 TSIA Reviewed By: cyrjano Differential Revision: D69957243 fbshipit-source-id: 70e0a7ce01ef9c3359f90473137a075f27aad139
1 parent 6f0f748 commit 3188421

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed

captum/testing/attr/helpers/test_config.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,19 @@
112112
"model": BasicModel_MultiLayer(),
113113
"attribute_args": {"inputs": torch.randn(4, 3), "target": 1},
114114
},
115+
{
116+
"name": "basic_single_target_cross_tensor_attributions",
117+
"algorithms": [
118+
FeatureAblation,
119+
FeaturePermutation,
120+
],
121+
"model": BasicModel_MultiLayer(),
122+
"attribute_args": {
123+
"inputs": torch.randn(4, 3),
124+
"target": 1,
125+
"enable_cross_tensor_attribution": True,
126+
},
127+
},
115128
{
116129
"name": "basic_multi_input",
117130
"algorithms": [
@@ -179,6 +192,21 @@
179192
},
180193
"dp_delta": 0.0005,
181194
},
195+
{
196+
"name": "basic_multi_input_multi_target_cross_tensor_attributions",
197+
"algorithms": [
198+
FeatureAblation,
199+
FeaturePermutation,
200+
],
201+
"model": BasicModel_MultiLayer_MultiInput(),
202+
"attribute_args": {
203+
"inputs": (10 * torch.randn(6, 3), 5 * torch.randn(6, 3)),
204+
"additional_forward_args": (2 * torch.randn(6, 3), 5),
205+
"target": [0, 1, 1, 0, 0, 1],
206+
"enable_cross_tensor_attribution": True,
207+
},
208+
"dp_delta": 0.0005,
209+
},
182210
{
183211
"name": "basic_multiple_tuple_target",
184212
"algorithms": [
@@ -202,6 +230,20 @@
202230
"additional_forward_args": (None, True),
203231
},
204232
},
233+
{
234+
"name": "basic_multiple_tuple_target_cross_tensor_attributions",
235+
"algorithms": [
236+
FeatureAblation,
237+
FeaturePermutation,
238+
],
239+
"model": BasicModel_MultiLayer(),
240+
"attribute_args": {
241+
"inputs": torch.randn(4, 3),
242+
"target": [(1, 0, 0), (0, 1, 1), (1, 1, 1), (0, 0, 0)],
243+
"additional_forward_args": (None, True),
244+
"enable_cross_tensor_attribution": True,
245+
},
246+
},
205247
{
206248
"name": "basic_tensor_single_target",
207249
"algorithms": [
@@ -243,6 +285,19 @@
243285
"target": torch.tensor([1, 1, 0, 0]),
244286
},
245287
},
288+
{
289+
"name": "basic_tensor_multi_target_cross_tensor_attributions",
290+
"algorithms": [
291+
FeatureAblation,
292+
FeaturePermutation,
293+
],
294+
"model": BasicModel_MultiLayer(),
295+
"attribute_args": {
296+
"inputs": torch.randn(4, 3),
297+
"target": torch.tensor([1, 1, 0, 0]),
298+
"enable_cross_tensor_attribution": True,
299+
},
300+
},
246301
# Primary Configs with Baselines
247302
{
248303
"name": "basic_multiple_tuple_target_with_baselines",
@@ -262,6 +317,20 @@
262317
"additional_forward_args": (None, True),
263318
},
264319
},
320+
{
321+
"name": "basic_multiple_tuple_target_with_baselines_cross_tensor_attributions",
322+
"algorithms": [
323+
FeatureAblation,
324+
],
325+
"model": BasicModel_MultiLayer(),
326+
"attribute_args": {
327+
"inputs": torch.randn(4, 3),
328+
"baselines": 0.5 * torch.randn(4, 3),
329+
"target": [(1, 0, 0), (0, 1, 1), (1, 1, 1), (0, 0, 0)],
330+
"additional_forward_args": (None, True),
331+
"enable_cross_tensor_attribution": True,
332+
},
333+
},
265334
{
266335
"name": "basic_tensor_single_target_with_baselines",
267336
"algorithms": [
@@ -279,6 +348,19 @@
279348
"target": torch.tensor([0]),
280349
},
281350
},
351+
{
352+
"name": "basic_tensor_single_target_with_baselines_cross_tensor_attributions",
353+
"algorithms": [
354+
FeatureAblation,
355+
],
356+
"model": BasicModel_MultiLayer(),
357+
"attribute_args": {
358+
"inputs": torch.randn(4, 3),
359+
"baselines": 0.5 * torch.randn(4, 3),
360+
"target": torch.tensor([0]),
361+
"enable_cross_tensor_attribution": True,
362+
},
363+
},
282364
# Primary Configs with Internal Batching
283365
{
284366
"name": "basic_multiple_tuple_target_with_internal_batching",

0 commit comments

Comments
 (0)